Oxford Nanopore Sequencing Benchmark¶
This report presents a benchmark of SNVs, indels, and SVs, and a characterisation of the ONT dataset used for this benchmark.
Methods¶
Data Processing¶
Basecalling: wf-basecalling v1.1.7
Alignment and Variant Calling: wf-human-variation v2.1.0
Quality Control Tools¶
NanoPlot: 1.42.0
- Generates summary statistics for each sample and creates visualizations of QC metrics for sequencing summaries and aligned BAM files.
NanoComp: 1.23.1
- Compares multiple sequencing runs and generates comparative plots.
mosdepth: 0.3.3
- Calculates sequencing depth across the human genome for each sample.
rtg-tools: 3.12.1
- Performs performs variant comparison against a truth dataset.
SURVIVOR: 1.0.7
- Performs merging of vcf files to compare SVs within a sample and among populations/samples.
In [1]:
# Standard library imports
import os
import glob
import gzip
import pickle
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from statistics import mean
from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Union
# Third-party imports
import numpy as np
import polars as pl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
import statsmodels.api as sm
import pysam
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy import stats
from scipy.optimize import curve_fit
from scipy.spatial.distance import cdist
from sklearn import metrics
from statsmodels.stats.multitest import multipletests
# Seaborn settings
sns.set_style("whitegrid")
sns.set_context("paper")
sns.set_palette("colorblind")
# Logging settings
logging.basicConfig(
level=logging.INFO,
format="%(name)s - %(levelname)s - %(message)s",
force=True,
)
logger = logging.getLogger(__name__)
Sequencing Quality Control¶
Aggregate table of the QC metrics from NanoStats for both singleplexed and multiplexed samples, from the aligned .cram files produced by wf-human-variation.
Unless otherwise specified, subsequent plots and statistics include only samples basecalled with the sup algorithm.
In [2]:
def get_default_column_types() -> Dict[str, str]:
"""
Define the data types for the columns in the DataFrame.
Returns:
Dict[str, str]: Mapping of column names to their types
"""
return {
"multiplexing": "category",
"basecall": "category",
"anonymised_sample": "category",
"number_of_reads": "numeric",
"number_of_bases": "numeric",
"number_of_bases_aligned": "numeric",
"fraction_bases_aligned": "numeric",
"mean_read_length": "numeric",
"median_read_length": "numeric",
"read_length_stdev": "numeric",
"n50": "numeric",
"mean_qual": "numeric",
"median_qual": "numeric",
"average_identity": "numeric",
"Reads >Q5_percentage": "numeric",
"Reads >Q7_percentage": "numeric",
"Reads >Q10_percentage": "numeric",
"Reads >Q12_percentage": "numeric",
"Reads >Q15_percentage": "numeric",
}
@dataclass
class NanoStatsConfig:
"""Configuration for NanoStats parsing with default settings."""
skip_categories: tuple = ("longest_read_(with_Q)", "highest_Q_read_(with_length)")
column_types: Dict[str, str] = field(default_factory=get_default_column_types)
required_metrics: tuple = (
"number_of_reads",
"number_of_bases",
"median_read_length",
"mean_read_length",
"read_length_stdev",
"n50",
"mean_qual",
"median_qual",
"Reads_>Q5",
"Reads_>Q7",
"Reads_>Q10",
"Reads_>Q12",
"Reads_>Q15",
)
def _parse_nanostats_file(file_path: Path) -> Dict[str, float]:
"""
Parse a NanoStats.txt file and extract metrics.
Args:
file_path (Path): Path to NanoStats.txt file
Returns:
Dict[str, float]: Dictionary of metrics and their values
Raises:
ValueError: If required metrics are missing from the file
"""
metrics = {}
try:
with open(file_path) as f:
next(f) # Skip header line
for line in f:
key, value = line.strip().split("\t")
key = key.strip(":")
if any(skip in key for skip in NanoStatsConfig.skip_categories):
continue
if key.startswith("Reads >Q"):
match = re.search(r"\((\d+\.\d+)%\)", value)
if match:
clean_key = key.replace(" ", "_")
metrics[clean_key] = float(match.group(1)) / 100
continue
try:
clean_value = value.split()[0].replace(",", "")
metrics[key.lower().replace(" ", "_")] = float(clean_value)
except (ValueError, IndexError):
logger.warning(
f"Could not parse value for metric {key} in {file_path}"
)
# Verify required metrics
missing_metrics = [
metric
for metric in NanoStatsConfig.required_metrics
if metric not in metrics
]
if missing_metrics:
raise ValueError(f"Missing required metrics: {missing_metrics}")
return metrics
except FileNotFoundError:
logger.error(f"NanoStats file not found: {file_path}")
raise
except Exception as e:
logger.error(f"Error parsing NanoStats file {file_path}: {str(e)}")
raise
def _get_multiplexing_status(seq_summaries_dir: Path, sample_id: str) -> str:
"""
Determine if a sample is multiplexed.
Args:
seq_summaries_dir (Path): Directory containing sequencing summaries
sample_id (str): Sample identifier
Returns:
str: 'multiplex' or 'singleplex'
"""
try:
for dir_path in seq_summaries_dir.glob("*"):
if not dir_path.is_dir():
continue
samples = dir_path.name.split("__")
if sample_id in samples:
return "multiplex" if len(samples) > 1 else "singleplex"
return "singleplex"
except Exception as e:
logger.error(f"Error determining multiplexing status for {sample_id}: {str(e)}")
raise
def _extract_sample_info(dir_path: Path) -> Tuple[str, str]:
"""
Extract sample ID and basecall type from directory path.
Args:
dir_path (Path): Directory path containing sample information
Returns:
Tuple[str, str]: Tuple of (sample_id, basecall_type)
Raises:
ValueError: If directory name format is invalid
"""
try:
parts = dir_path.name.split("_")
if len(parts) < 2:
raise ValueError(f"Invalid directory name format: {dir_path.name}")
sample_id = "_".join(parts[:-1])
basecall = parts[-1]
return sample_id, basecall
except Exception as e:
logger.error(f"Error extracting sample info from {dir_path}: {str(e)}")
raise
def parse_nanostats(
aligned_bams_dir: Path,
seq_summaries_dir: Path,
) -> pl.DataFrame:
"""
Parse NanoStats files and create a DataFrame with QC metrics.
Args:
aligned_bams_dir (Path): Directory containing aligned BAM files
seq_summaries_dir (Path): Directory containing sequencing summaries
Returns:
pl.DataFrame: Polars DataFrame containing parsed metrics, sorted by sample ID
Raises:
FileNotFoundError: If input directories don't exist
ValueError: If no valid samples are found
"""
if not aligned_bams_dir.exists():
raise FileNotFoundError(f"Aligned BAMs directory not found: {aligned_bams_dir}")
if not seq_summaries_dir.exists():
raise FileNotFoundError(
f"Sequencing summaries directory not found: {seq_summaries_dir}"
)
data: List[Dict] = []
sample_ids = set()
# Collect all sample IDs
for dir_path in aligned_bams_dir.glob("*"):
if not dir_path.is_dir():
continue
sample_id, _ = _extract_sample_info(dir_path)
sample_ids.add(sample_id)
if not sample_ids:
raise ValueError("No valid samples found in the input directory")
# Create anonymised sample mapping
sample_mapping = {
sample_id: f"Sample {i+1}" for i, sample_id in enumerate(sorted(sample_ids))
}
# Process each sample directory
for dir_path in aligned_bams_dir.glob("*"):
if not dir_path.is_dir():
continue
try:
nanostats_file = dir_path / "NanoStats.txt"
sample_id, basecall = _extract_sample_info(dir_path)
metrics = _parse_nanostats_file(nanostats_file)
sample_data = {
"sample": sample_id,
"anonymised_sample": sample_mapping[sample_id],
"basecall": basecall,
"multiplexing": _get_multiplexing_status(seq_summaries_dir, sample_id),
**metrics,
}
data.append(sample_data)
except Exception as e:
logger.error(f"Error processing directory {dir_path}: {str(e)}")
continue
if not data:
raise ValueError("No valid data could be parsed from any sample")
df = pl.DataFrame(data)
return df.sort(["sample", "anonymised_sample"])
np_seq_summaries_dir = Path(
"/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/seq_summaries/"
)
np_aligned_bams_dir = Path(
"/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/aligned_bams/"
)
nanoplot_qc_metrics_df = parse_nanostats(
aligned_bams_dir=np_aligned_bams_dir,
seq_summaries_dir=np_seq_summaries_dir,
)
logger.info(f"Successfully processed {len(nanoplot_qc_metrics_df)} samples")
with pl.Config(tbl_rows=len(nanoplot_qc_metrics_df)):
display(nanoplot_qc_metrics_df)
__main__ - INFO - Successfully processed 14 samples
shape: (14, 21)
| sample | anonymised_sample | basecall | multiplexing | number_of_reads | number_of_bases | number_of_bases_aligned | fraction_bases_aligned | median_read_length | mean_read_length | read_length_stdev | n50 | average_identity | median_identity | mean_qual | median_qual | Reads_>Q5 | Reads_>Q7 | Reads_>Q10 | Reads_>Q12 | Reads_>Q15 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| str | str | str | str | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "A046_12" | "Sample 1" | "sup" | "multiplex" | 5.063989e6 | 3.3558e10 | 3.1483e10 | 0.9 | 4600.0 | 6626.8 | 6582.9 | 11608.0 | 97.5 | 98.9 | 17.9 | 20.9 | 0.999 | 0.998 | 0.993 | 0.977 | 0.892 |
| "A048_09" | "Sample 2" | "sup" | "multiplex" | 7.088011e6 | 4.0078e10 | 3.7407e10 | 0.9 | 2996.0 | 5654.3 | 6755.6 | 11981.0 | 97.5 | 98.8 | 17.6 | 20.3 | 0.999 | 0.999 | 0.994 | 0.978 | 0.883 |
| "A079_07" | "Sample 3" | "sup" | "multiplex" | 3.813948e6 | 3.1232e10 | 2.9834e10 | 1.0 | 5219.0 | 8189.0 | 8440.0 | 15713.0 | 97.4 | 98.8 | 17.8 | 20.5 | 0.999 | 0.998 | 0.993 | 0.978 | 0.892 |
| "A081_91" | "Sample 4" | "sup" | "multiplex" | 3.278883e6 | 2.5565e10 | 2.4259e10 | 0.9 | 4066.0 | 7797.0 | 9011.5 | 16853.0 | 97.4 | 98.8 | 17.7 | 20.4 | 0.999 | 0.998 | 0.993 | 0.979 | 0.89 |
| "A085_00" | "Sample 5" | "sup" | "multiplex" | 3.767749e6 | 2.7359e10 | 2.5779e10 | 0.9 | 4120.0 | 7261.4 | 8121.2 | 15061.0 | 97.2 | 98.7 | 17.1 | 19.8 | 0.999 | 0.998 | 0.991 | 0.969 | 0.857 |
| "A097_92" | "Sample 6" | "sup" | "multiplex" | 4.264823e6 | 3.5497e10 | 3.3586e10 | 0.9 | 6204.0 | 8323.1 | 7593.4 | 14429.0 | 97.1 | 98.7 | 17.3 | 20.1 | 0.999 | 0.998 | 0.991 | 0.973 | 0.87 |
| "A149_01" | "Sample 7" | "sup" | "singleplex" | 8.228301e6 | 4.7190e10 | 4.4520e10 | 0.9 | 3532.0 | 5735.1 | 6354.1 | 10895.0 | 97.4 | 98.8 | 17.6 | 20.3 | 0.999 | 0.999 | 0.994 | 0.976 | 0.877 |
| "A153_01" | "Sample 8" | "sup" | "singleplex" | 7.346662e6 | 5.0667e10 | 4.8132e10 | 0.9 | 6660.0 | 6896.7 | 5705.6 | 10074.0 | 97.4 | 98.9 | 18.3 | 21.4 | 1.0 | 0.999 | 0.995 | 0.983 | 0.905 |
| "A153_06" | "Sample 9" | "sup" | "singleplex" | 1.1559255e7 | 7.8039e10 | 7.3883e10 | 0.9 | 5641.0 | 6751.2 | 6382.8 | 10973.0 | 97.4 | 98.9 | 17.8 | 20.7 | 0.999 | 0.998 | 0.993 | 0.978 | 0.892 |
| "A154_04" | "Sample 10" | "sup" | "singleplex" | 9.270031e6 | 5.4649e10 | 5.1180e10 | 0.9 | 3996.0 | 5895.2 | 6062.2 | 10939.0 | 97.3 | 98.8 | 17.6 | 20.6 | 0.999 | 0.998 | 0.992 | 0.975 | 0.879 |
| "A154_06" | "Sample 11" | "sup" | "singleplex" | 8.338801e6 | 5.7601e10 | 5.4980e10 | 1.0 | 6073.0 | 6907.6 | 6178.5 | 11330.0 | 97.3 | 98.7 | 17.6 | 20.2 | 1.0 | 0.999 | 0.995 | 0.98 | 0.882 |
| "A157_02" | "Sample 12" | "sup" | "singleplex" | 8.250276e6 | 5.3573e10 | 5.1193e10 | 1.0 | 6075.0 | 6493.5 | 5350.5 | 9628.0 | 97.3 | 98.8 | 18.0 | 20.9 | 1.0 | 0.999 | 0.995 | 0.981 | 0.893 |
| "A160_96" | "Sample 13" | "sup" | "singleplex" | 1.1591344e7 | 8.3603e10 | 8.0174e10 | 1.0 | 6601.0 | 7212.5 | 6210.7 | 11126.0 | 97.4 | 98.8 | 17.8 | 20.6 | 1.0 | 0.999 | 0.995 | 0.981 | 0.884 |
| "A162_09" | "Sample 14" | "sup" | "singleplex" | 1.487444e7 | 8.9966e10 | 8.5759e10 | 1.0 | 3957.0 | 6048.3 | 6378.4 | 11325.0 | 97.5 | 98.8 | 17.7 | 20.3 | 1.0 | 0.999 | 0.995 | 0.979 | 0.878 |
In [3]:
def _create_yield_plot(
ax: plt.Axes,
data: pl.DataFrame,
x: str,
y: str,
hue: str,
title: str,
xlabel: str,
ylabel: str,
) -> None:
"""
Create a bar plot showing yield metrics.
Args:
ax (plt.Axes): Matplotlib axes object to plot on
data (pl.DataFrame): Polars DataFrame containing the data
x (str): Column name for x-axis
y (str): Column name for y-axis
hue (str): Column name for color grouping
title (str): Plot title
xlabel (str): X-axis label
ylabel (str): Y-axis label
Raises:
ValueError: If required columns are not found in the DataFrame
"""
try:
# Validate input columns
required_cols = {x, y, hue}
if not required_cols.issubset(data.columns):
missing = required_cols - set(data.columns)
raise ValueError(f"Missing required columns: {missing}")
# Create plot with sorted data
sns.barplot(x=x, y=y, hue=hue, data=data, ax=ax)
ax.set_title(title)
ax.set_xlabel(xlabel)
# Get the scale factor from the formatter
formatter = ax.yaxis.get_major_formatter()
if hasattr(formatter, "orderOfMagnitude"):
scale = formatter.orderOfMagnitude
ylabel = f"{ylabel} ($1×10^{{{scale}}}$)"
ax.set_ylabel(ylabel)
# Rotate x-axis labels
for tick in ax.get_xticklabels():
tick.set_rotation(45)
tick.set_ha("right")
# Adjust x-tick positions
locs, _ = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.2 for loc in locs])
# Create legend
legend = ax.legend(title=hue.title())
legend.get_title().set_weight("bold")
except Exception as e:
logger.error(f"Error creating yield plot: {str(e)}")
raise
def plot_sample_yields(
metrics_df: pl.DataFrame,
basecall_type: str = "sup",
figsize: Tuple[int, int] = (16, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create plots showing read and base yields for samples.
Args:
metrics_df (pl.DataFrame): Polars DataFrame containing metrics data
basecall_type (str, optional): Basecall type to filter for. Defaults to "sup".
figsize (Tuple[int, int], optional): Figure size. Defaults to (16, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[plt.Figure]: Figure object if created independently (no GridSpec provided)
Raises:
ValueError: If DataFrame doesn't contain required columns.
"""
try:
# Validate input data
required_cols = {
"basecall",
"anonymised_sample",
"number_of_reads",
"number_of_bases",
"multiplexing",
}
if not required_cols.issubset(metrics_df.columns):
missing = required_cols - set(metrics_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Filter data
yields_df = metrics_df.filter(pl.col("basecall") == basecall_type)
if len(yields_df) == 0:
raise ValueError(f"No data found for basecall_type: {basecall_type}")
# Create figure and axes based on whether GridSpec is provided
if gs is None:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
# Create plots
_create_yield_plot(
ax1,
yields_df,
"anonymised_sample",
"number_of_reads",
"multiplexing",
f"Read Yield per Sample",
"Sample",
"Number of Reads",
)
_create_yield_plot(
ax2,
yields_df,
"anonymised_sample",
"number_of_bases",
"multiplexing",
f"Base Yield per Sample",
"Sample",
"Number of Bases",
)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting sample yields: {str(e)}")
raise
yields_plots = plot_sample_yields(nanoplot_qc_metrics_df)
Summary Stats¶
In [4]:
@dataclass
class YieldMetrics:
"""
Data class for storing yield metrics statistics.
"""
max: float
min: float
mean: float
std: float
median: float
@dataclass
class YieldStats:
"""
Data class for storing read and base yield statistics.
"""
reads: YieldMetrics
bases: YieldMetrics
def _format_number_separator(num: float) -> str:
"""
Format a number with thousand separators.
Args:
num (float): Number to format with thousand separators
Returns:
str: Formatted string representation of the number with thousand separators
Examples:
>>> _format_number_separator(1234567.89)
'1,234,568'
"""
return f"{num:,.0f}"
def _calculate_yield_stats(df: pl.DataFrame) -> YieldStats:
"""
Calculate yield statistics from a Polars DataFrame.
Args:
df (pl.DataFrame): Input DataFrame containing yield metrics with columns
'number_of_reads' and 'number_of_bases'
Returns:
YieldStats: Object containing read and base statistics
Raises:
Exception: If there's an error calculating statistics from the DataFrame
KeyError: If required columns are missing from the DataFrame
"""
try:
reads_metrics = YieldMetrics(
max=df.select(pl.col("number_of_reads").max()).item(),
min=df.select(pl.col("number_of_reads").min()).item(),
mean=df.select(pl.col("number_of_reads").mean()).item(),
std=df.select(pl.col("number_of_reads").std()).item(),
median=df.select(pl.col("number_of_reads").median()).item(),
)
bases_metrics = YieldMetrics(
max=df.select(pl.col("number_of_bases").max()).item(),
min=df.select(pl.col("number_of_bases").min()).item(),
mean=df.select(pl.col("number_of_bases").mean()).item(),
std=df.select(pl.col("number_of_bases").std()).item(),
median=df.select(pl.col("number_of_bases").median()).item(),
)
return YieldStats(reads=reads_metrics, bases=bases_metrics)
except Exception as e:
logger.error(f"Error calculating yield statistics: {str(e)}")
raise
def _print_yield_stats(stats: YieldStats, sample_type: str) -> None:
"""
Print formatted yield statistics.
Args:
stats (YieldStats): YieldStats object containing statistics to print
sample_type (str): Type of sample (Multiplexed/Singleplexed)
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing statistics for {sample_type} samples")
print(f"\n{sample_type} Samples Statistics:")
print("=" * 40)
for metric_name, metrics in [("Reads", stats.reads), ("Bases", stats.bases)]:
print(f"\n{metric_name}:")
for stat_name, value in vars(metrics).items():
formatted_value = _format_number_separator(value)
print(f" {stat_name.capitalize():6s}: {formatted_value}")
except Exception as e:
logger.error(f"Error printing yield statistics: {str(e)}")
raise
def _calculate_percentage_increase(
singleplex_val: float, multiplex_val: float
) -> float:
"""
Calculate percentage increase between two values.
Args:
singleplex_val (float): Value from singleplex samples
multiplex_val (float): Value from multiplex samples
Returns:
float: Percentage increase between the two values
Raises:
ZeroDivisionError: If multiplex value is zero
Exception: For other calculation errors
Examples:
>>> _calculate_percentage_increase(200, 100)
100.0
"""
try:
return ((singleplex_val - multiplex_val) / multiplex_val) * 100
except ZeroDivisionError:
logger.error("Cannot calculate percentage increase: multiplex value is zero")
raise
except Exception as e:
logger.error(f"Error calculating percentage increase: {str(e)}")
raise
def analyze_yields(df: pl.DataFrame) -> None:
"""
Analyze and print yield statistics for multiplexed and singleplexed samples.
Args:
df (pl.DataFrame): Input DataFrame containing yield metrics with columns:
- multiplexing: str ('singleplex' or 'multiplex')
- basecall: str ('sup' or other)
- number_of_reads: int/float
- number_of_bases: int/float
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
singleplex_yields = df.filter(
(pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
)
multiplex_yields = df.filter(
(pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
)
if singleplex_yields.height == 0 or multiplex_yields.height == 0:
logger.warning("No data found for either singleplex or multiplex samples")
return
singleplex_stats = _calculate_yield_stats(singleplex_yields)
multiplex_stats = _calculate_yield_stats(multiplex_yields)
_print_yield_stats(singleplex_stats, "Singleplexed")
_print_yield_stats(multiplex_stats, "Multiplexed")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
increase_reads = _calculate_percentage_increase(
singleplex_stats.reads.mean, multiplex_stats.reads.mean
)
increase_bases = _calculate_percentage_increase(
singleplex_stats.bases.mean, multiplex_stats.bases.mean
)
print(f"Mean Number of Reads: {increase_reads:6.2f}%")
print(f"Mean Number of Bases: {increase_bases:6.2f}%")
logger.info("Yield analysis completed successfully")
except Exception as e:
logger.error(f"Error in yield analysis: {str(e)}")
raise
analyze_yields(nanoplot_qc_metrics_df)
__main__ - INFO - Printing statistics for Singleplexed samples
__main__ - INFO - Printing statistics for Multiplexed samples
__main__ - INFO - Yield analysis completed successfully
Singleplexed Samples Statistics: ======================================== Reads: Max : 14,874,440 Min : 7,346,662 Mean : 9,932,389 Std : 2,541,662 Median: 8,804,416 Bases: Max : 89,965,577,035 Min : 47,190,301,151 Mean : 64,410,968,490 Std : 16,697,589,113 Median: 56,124,706,342 Multiplexed Samples Statistics: ======================================== Reads: Max : 7,088,011 Min : 3,278,883 Mean : 4,546,234 Std : 1,382,487 Median: 4,039,386 Bases: Max : 40,077,847,821 Min : 25,565,442,444 Mean : 32,214,905,852 Std : 5,350,881,127 Median: 32,395,219,886 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Number of Reads: 118.48% Mean Number of Bases: 99.94%
2. Read Lengths¶
In [5]:
def _create_length_subplot(
data: pl.DataFrame, ax: plt.Axes, title: str, hue: str
) -> None:
"""
Create a subplot showing read length metrics.
Args:
data (pl.DataFrame): DataFrame containing the plot data
ax (plt.Axes): Matplotlib axes object to plot on
title (str): Plot title
hue (str): Column name for color grouping
Raises:
ValueError: If required columns are missing from DataFrame
"""
try:
sns.barplot(
x="anonymised_sample",
y="read_length",
hue=hue,
data=data,
errorbar=None,
ax=ax,
)
ax.set_title(title)
ax.set_xlabel("Sample")
formatter = ax.yaxis.get_major_formatter()
ylabel = "Read Length (bp)"
ax.set_ylabel(ylabel)
for tick in ax.get_xticklabels():
tick.set_rotation(45)
tick.set_ha("right")
locs, _ = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.2 for loc in locs])
legend = ax.legend(title=hue.title())
legend.get_title().set_weight("bold")
except Exception as e:
logger.error(f"Error creating length subplot: {str(e)}")
raise
def plot_read_lengths(
metrics_df: pl.DataFrame,
figsize: Tuple[int, int] = (16, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create plots showing read length distributions for samples.
Args:
metrics_df (pl.DataFrame): Input DataFrame containing metrics data
figsize (Tuple[int, int], optional): Figure size. Defaults to (16, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[plt.Figure]: Figure object if created independently.
Raises:
ValueError: If DataFrame doesn't contain required columns.
"""
try:
required_cols = {
"basecall",
"sample",
"anonymised_sample",
"multiplexing",
"mean_read_length",
"median_read_length",
}
if not required_cols.issubset(metrics_df.columns):
missing = required_cols - set(metrics_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Filter and prepare data
plot_data = (
metrics_df.filter(pl.col("basecall") == "sup")
.select(
[
"sample",
"anonymised_sample",
"multiplexing",
"mean_read_length",
"median_read_length",
]
)
.unpivot(
index=["sample", "anonymised_sample", "multiplexing"],
on=["mean_read_length", "median_read_length"],
variable_name="read_length_type",
value_name="read_length",
)
)
if gs is None:
fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=True, dpi=dpi)
else:
fig = plt.gcf()
axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])]
for ax, read_length_type in zip(
axes, ["mean_read_length", "median_read_length"]
):
title = read_length_type.replace("_", " ").title()
data = plot_data.filter(pl.col("read_length_type") == read_length_type)
_create_length_subplot(data, ax, title, hue="multiplexing")
if ax != axes[0]:
ax.yaxis.set_tick_params(labelleft=True)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting read lengths: {str(e)}")
raise
read_lengths_plot = plot_read_lengths(nanoplot_qc_metrics_df)
In [6]:
def load_nanoplot_data(base_dir: Path, metrics_df: pl.DataFrame) -> pl.DataFrame:
"""
Load NanoPlot data from pickle files and combine with metrics.
Args:
base_dir (Path): Base directory containing NanoPlot data files
metrics_df (pl.DataFrame): DataFrame containing sample metrics
Returns:
pl.DataFrame: Combined NanoPlot data for all samples
Raises:
FileNotFoundError: If pickle file is not found
ValueError: If required columns are missing
"""
required_columns = ("readIDs", "quals", "lengths", "mapQ")
data_list = []
try:
for row in metrics_df.iter_rows(named=True):
sample_dir = f"{row['sample']}_{row['basecall']}"
pickle_path = base_dir / sample_dir / "NanoPlot-data.pickle"
if not pickle_path.is_file():
logger.warning(f"Pickle file not found: {pickle_path}")
continue
with open(pickle_path, "rb") as file:
nanoplot_data = pickle.load(file)
sample_df = pl.DataFrame(nanoplot_data).select(required_columns)
sample_df = sample_df.with_columns(
[
pl.lit(row["anonymised_sample"]).alias("anonymised_sample"),
pl.lit(row["basecall"]).alias("basecall"),
]
)
data_list.append(sample_df)
if not data_list:
raise ValueError("No valid data found in any pickle files")
return pl.concat(data_list)
except Exception as e:
logger.error(f"Error loading NanoPlot data: {str(e)}")
raise
def process_aligned_nanoplot_data(
nanoplot_data: pl.DataFrame, metrics_df: pl.DataFrame
) -> pl.DataFrame:
"""
Process NanoPlot data by merging with metrics and binning read lengths.
Args:
nanoplot_data (pl.DataFrame): Raw NanoPlot data
metrics_df (pl.DataFrame): Metrics DataFrame
Returns:
pl.DataFrame: Processed DataFrame with binned lengths
Raises:
ValueError: If required columns are missing
"""
try:
metrics_subset = metrics_df.select(
["anonymised_sample", "multiplexing", "basecall", "number_of_reads"]
)
processed_data = nanoplot_data.join(
metrics_subset, on=["anonymised_sample", "basecall"]
)
max_length = processed_data.select(pl.col("lengths").max()).item()
bins = np.logspace(np.log10(10), np.log10(max_length), num=100)
processed_data = processed_data.with_columns(
[pl.col("lengths").cut(bins).alias("length_bin")]
)
return processed_data
except Exception as e:
logger.error(f"Error processing NanoPlot data: {str(e)}")
raise
def calculate_read_length_distribution(
processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
"""
Calculate read length distribution statistics.
Args:
processed_data (pl.DataFrame): Processed NanoPlot data
basecall_type (str, optional): Basecall type to filter. Defaults to "sup"
Returns:
pl.DataFrame: Length distribution statistics
Raises:
ValueError: If required columns are missing
"""
try:
filtered_data = processed_data.filter(pl.col("basecall") == basecall_type)
# Extract bin edges from the categorical length_bin column
bin_categories = filtered_data.select(pl.col("length_bin").unique()).to_series()
bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]
bin_centers = [
(bin_edges[i] + bin_edges[i + 1]) / 2 for i in range(len(bin_edges) - 1)
]
# Create a mapping DataFrame to calculate bin centers
mapping_df = pl.DataFrame(
{
"length_bin": bin_categories[:-1], # Exclude the last bin edge
"bin_center": bin_centers,
}
)
length_dist = (
filtered_data.group_by(["anonymised_sample", "length_bin", "multiplexing"])
.agg(pl.len().alias("count"))
.join(
filtered_data.select(["anonymised_sample", "number_of_reads"]).unique(),
on="anonymised_sample",
)
)
# Add percentage and join with mapping DataFrame to get bin centers
length_dist = length_dist.with_columns(
[(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
).join(mapping_df, on="length_bin", how="left")
return length_dist
except Exception as e:
logger.error(f"Error calculating length distribution: {str(e)}")
raise
def plot_read_length_distribution(
length_dist: pl.DataFrame,
max_length: float,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
x_scale: str = "log",
x_min: int = 10,
num_ticks: int = 20,
line_alpha: float = 0.8,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot the distribution of read lengths across samples.
Args:
length_dist (pl.DataFrame): DataFrame containing length distribution data
max_length (float): Maximum read length for x-axis limit
figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
x_scale (str, optional): Scale for x-axis. Defaults to "log".
x_min (int, optional): Minimum x-axis value. Defaults to 10.
num_ticks (int, optional): Number of x-axis ticks. Defaults to 20.
line_alpha (float, optional): Line transparency. Defaults to 0.8.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[Figure]: Figure object if created independently.
Raises:
ValueError: If required columns are missing.
"""
try:
required_cols = {
"anonymised_sample",
"bin_center",
"percentage",
"multiplexing",
}
if not all(col in length_dist.columns for col in required_cols):
missing = required_cols - set(length_dist.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
# Filter and prepare data
non_zero_samples = (
length_dist.group_by("anonymised_sample")
.agg(pl.col("percentage").sum())
.filter(pl.col("percentage") > 0)
.select("anonymised_sample")
)
filtered_dist = (
length_dist.join(non_zero_samples, on="anonymised_sample")
.with_columns(
[
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
]
)
.sort("sample_num")
.rename(
{
"multiplexing": r"$\mathbf{Multiplexing}$",
"anonymised_sample": r"$\mathbf{Sample}$",
}
)
)
sns.lineplot(
data=filtered_dist,
x="bin_center",
y="percentage",
hue=r"$\mathbf{Sample}$",
style=r"$\mathbf{Multiplexing}$",
alpha=line_alpha,
ax=ax,
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1.1))
ax.set_xscale(x_scale)
ax.set_xlabel("Read Length (bp)")
ax.set_ylabel("Proportion of Reads (%)")
ax.set_title("Distribution of Read Lengths")
# Set x-axis ticks
tick_positions = np.logspace(
np.log10(x_min), np.log10(max_length), num=num_ticks
)
ax.set_xticks(tick_positions)
ax.set_xticklabels([f"{int(tick):,}" for tick in tick_positions])
# Set axis limits
ax.set_xlim(left=x_min, right=max_length)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting read length distribution: {str(e)}")
raise
nanoplot_aligned_metrics = load_nanoplot_data(
np_aligned_bams_dir, nanoplot_qc_metrics_df
)
processed_aligned_nanoplot_df = process_aligned_nanoplot_data(
nanoplot_aligned_metrics, nanoplot_qc_metrics_df
)
max_read_length = processed_aligned_nanoplot_df.select(pl.col("lengths").max()).item()
read_length_distribution = calculate_read_length_distribution(
processed_aligned_nanoplot_df
)
read_length_dist_plot = plot_read_length_distribution(
read_length_distribution, max_read_length
)
In [7]:
@dataclass
class LengthMetrics:
"""
Data class for storing read length metrics statistics.
"""
max: float
min: float
mean: float
std: float
median: float
def _calculate_length_stats(df: pl.DataFrame) -> LengthMetrics:
"""
Calculate read length statistics from a Polars DataFrame.
Args:
df (pl.DataFrame): Input DataFrame containing length metrics with column 'lengths'
Returns:
LengthMetrics: Object containing read length statistics
Raises:
Exception: If there's an error calculating statistics from the DataFrame
KeyError: If required column is missing from the DataFrame
"""
try:
return LengthMetrics(
max=df.select(pl.col("lengths").max()).item(),
min=df.select(pl.col("lengths").min()).item(),
mean=df.select(pl.col("lengths").mean()).item(),
std=df.select(pl.col("lengths").std()).item(),
median=df.select(pl.col("lengths").median()).item(),
)
except Exception as e:
logger.error(f"Error calculating length statistics: {str(e)}")
raise
def _print_length_stats(stats: LengthMetrics, sample_type: str) -> None:
"""
Print formatted read length statistics.
Args:
stats (LengthMetrics): LengthMetrics object containing statistics to print
sample_type (str): Type of sample (Multiplexed/Singleplexed)
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing length statistics for {sample_type} samples")
print(f"\n{sample_type} Samples Statistics:")
print("=" * 40)
print("\nRead Lengths:")
for stat_name, value in vars(stats).items():
formatted_value = _format_number_separator(value)
print(f" {stat_name.capitalize():6s}: {formatted_value}")
except Exception as e:
logger.error(f"Error printing length statistics: {str(e)}")
raise
def analyze_lengths(df: pl.DataFrame) -> None:
"""
Analyze and print read length statistics for multiplexed and singleplexed samples.
Args:
df (pl.DataFrame): Input DataFrame containing length metrics with columns:
- multiplexing: str ('singleplex' or 'multiplex')
- basecall: str ('sup' or other)
- lengths: int/float
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
singleplex_lengths = df.filter(
(pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
)
multiplex_lengths = df.filter(
(pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
)
if singleplex_lengths.height == 0 or multiplex_lengths.height == 0:
logger.warning("No data found for either singleplex or multiplex samples")
return
singleplex_stats = _calculate_length_stats(singleplex_lengths)
multiplex_stats = _calculate_length_stats(multiplex_lengths)
_print_length_stats(singleplex_stats, "Singleplexed")
_print_length_stats(multiplex_stats, "Multiplexed")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat_name in ["mean", "median"]:
increase = _calculate_percentage_increase(
getattr(singleplex_stats, stat_name),
getattr(multiplex_stats, stat_name),
)
print(f"{stat_name.capitalize():6s} Read Length: {increase:6.2f}%")
logger.info("Length analysis completed successfully")
except Exception as e:
logger.error(f"Error in length analysis: {str(e)}")
raise
analyze_lengths(processed_aligned_nanoplot_df)
__main__ - INFO - Printing length statistics for Singleplexed samples
__main__ - INFO - Printing length statistics for Multiplexed samples
__main__ - INFO - Length analysis completed successfully
Singleplexed Samples Statistics: ======================================== Read Lengths: Max : 836,896 Min : 40 Mean : 6,485 Std : 6,155 Median: 5,369 Multiplexed Samples Statistics: ======================================== Read Lengths: Max : 387,118 Min : 40 Mean : 7,086 Std : 7,668 Median: 4,225 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Read Length: -8.48% Median Read Length: 27.08%
3. Combined Plots¶
In [8]:
def create_combined_yield_plot(
metrics_df: pl.DataFrame, figsize: Tuple[int, int] = (12, 10), dpi: int = 300
) -> plt.Figure:
"""
Create a combined plot showing yields, read lengths, and length distribution.
Args:
metrics_df (pl.DataFrame): DataFrame containing metrics data
figsize (Tuple[int, int], optional): Figure size. Defaults to (12, 12).
dpi (int, optional): DPI for the figure. Defaults to 300.
Returns:
plt.Figure: Combined figure object
Raises:
ValueError: If required data is missing
"""
try:
# Create figure and GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(3, 2, height_ratios=[0.6, 0.9, 1.2])
# Plot yields (A and B)
plot_sample_yields(metrics_df, gs=gs)
# Plot read lengths (C and D)
plot_read_lengths(
metrics_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, gs[1, :])
)
# Plot read length distribution (E)
plot_read_length_distribution(
read_length_distribution,
max_read_length,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, gs[2, :]),
)
# Add panel labels
for i, label in enumerate(["A", "B", "C", "D", "E"]):
ax = fig.axes[i]
ax.text(
-0.05,
1.07,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
# Remove redundant legends
for ax in fig.axes[1:4]:
ax.get_legend().remove()
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined yield plot: {str(e)}")
raise
combined_yield_plot = create_combined_yield_plot(nanoplot_qc_metrics_df)
In [9]:
def calculate_base_quality_distribution(
processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
"""
Calculate distribution of base quality scores across samples.
Args:
processed_data (pl.DataFrame): Processed NanoPlot data
basecall_type (str, optional): Basecall type to filter. Defaults to "sup"
Returns:
pl.DataFrame: Quality distribution statistics
Raises:
ValueError: If required columns are missing
"""
try:
filtered_data = processed_data.filter(pl.col("basecall") == basecall_type)
min_qual = filtered_data.select(pl.col("quals").min()).item()
max_qual = filtered_data.select(pl.col("quals").max()).item()
bins = np.arange(min_qual, max_qual + 0.5, 0.5)
quality_dist = (
filtered_data.with_columns([pl.col("quals").cut(bins).alias("quals_bin")])
.group_by(["anonymised_sample", "quals_bin", "multiplexing"])
.agg(pl.len().alias("count"))
.join(
filtered_data.select(["anonymised_sample", "number_of_reads"]).unique(),
on="anonymised_sample",
)
)
# Calculate bin centers and percentages
bin_categories = quality_dist.select(pl.col("quals_bin").unique()).to_series()
bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]
bin_centers = [
(bin_edges[i] + bin_edges[i + 1]) / 2 for i in range(len(bin_edges) - 1)
]
mapping_df = pl.DataFrame(
{"quals_bin": bin_categories, "quals_bin_lower": bin_edges}
)
quality_dist = quality_dist.with_columns(
[(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
).join(
mapping_df,
on="quals_bin",
how="left",
)
return quality_dist
except Exception as e:
logger.error(f"Error calculating quality distribution: {str(e)}")
raise
def plot_base_quality_distribution(
quality_dist: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
line_alpha: float = 0.8,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot the distribution of base quality scores across samples.
Args:
quality_dist (pl.DataFrame): DataFrame containing quality distribution data
figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
line_alpha (float, optional): Line transparency. Defaults to 0.8.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[Figure]: Figure object if created independently.
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {
"anonymised_sample",
"quals_bin_lower",
"percentage",
"multiplexing",
}
if not all(col in quality_dist.columns for col in required_cols):
missing = required_cols - set(quality_dist.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
plot_data = (
quality_dist.with_columns(
[
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
]
)
.sort("sample_num")
.rename(
{
"multiplexing": r"$\mathbf{Multiplexing}$",
"anonymised_sample": r"$\mathbf{Sample}$",
}
)
)
sns.lineplot(
data=plot_data,
x="quals_bin_lower",
y="percentage",
hue=r"$\mathbf{Sample}$",
style=r"$\mathbf{Multiplexing}$",
alpha=line_alpha,
ax=ax,
)
ax.legend(loc="upper right", bbox_to_anchor=(1.02, 1.1))
ax.set_xlabel("Quality Score")
ax.set_ylabel("Proportion of Reads (%)")
ax.set_title("Distribution of Base Quality Scores")
max_qual = int(plot_data["quals_bin_lower"].max())
tick_positions = np.arange(0, max_qual + 1, 5)
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_positions)
if gs is None:
plt.tight_layout()
return fig
else:
ax.get_legend().remove()
return None
except Exception as e:
logger.error(f"Error plotting quality distribution: {str(e)}")
raise
base_quality_distribution_df = calculate_base_quality_distribution(
processed_aligned_nanoplot_df
)
base_quality_dist_plot = plot_base_quality_distribution(base_quality_distribution_df)
In [10]:
def prepare_qscore_percentage_data(metrics_df: pl.DataFrame) -> pl.DataFrame:
"""
Prepare QScore percentage data for visualization.
Args:
metrics_df (pl.DataFrame): DataFrame containing QC metrics
Returns:
pl.DataFrame: Processed QScore percentage data
Raises:
ValueError: If required columns are missing
"""
try:
qscore_columns = [
"Reads_>Q5",
"Reads_>Q7",
"Reads_>Q10",
"Reads_>Q12",
"Reads_>Q15",
]
# Verify required columns exist
missing_cols = [col for col in qscore_columns if col not in metrics_df.columns]
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
# Filter and unpivot DataFrame
qscore_df = (
metrics_df.filter(pl.col("basecall") == "sup")
.select(
[
"anonymised_sample",
"multiplexing",
*qscore_columns,
]
)
.unpivot(
index=["anonymised_sample", "multiplexing"],
on=qscore_columns,
variable_name="Quality_Score",
value_name="Percentage",
)
.with_columns(
[
# Extract just the Qn part from the Quality_Score column
pl.col("Quality_Score")
.str.extract(r">Q(\d+)")
.map_elements(
lambda x: f"Q{x}" if x is not None else None,
return_dtype=pl.Utf8,
)
.alias("Quality_Score")
]
)
)
return qscore_df
except Exception as e:
logger.error(f"Error preparing QScore percentage data: {str(e)}")
raise
def plot_qscore_percentage(
qscore_df: pl.DataFrame,
figsize: Tuple[int, int] = (20, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot QScore percentage distribution across samples.
Args:
qscore_df (pl.DataFrame): DataFrame containing QScore percentage data
figsize (Tuple[int, int], optional): Figure size. Defaults to (20, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[Figure]: Figure object if created independently.
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {
"anonymised_sample",
"multiplexing",
"Quality_Score",
"Percentage",
}
if not all(col in qscore_df.columns for col in required_cols):
missing = required_cols - set(qscore_df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])
plot_data = qscore_df.with_columns(
[
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num"),
pl.col("anonymised_sample").alias(r"$\mathbf{Sample}$"),
]
).sort("sample_num")
quality_score_order = ["Q5", "Q7", "Q10", "Q12", "Q15"]
for ax, multiplex_type in zip([ax1, ax2], ["multiplex", "singleplex"]):
data = plot_data.filter(pl.col("multiplexing") == multiplex_type)
sns.barplot(
data=data,
x=r"$\mathbf{Sample}$",
y="Percentage",
hue="Quality_Score",
errorbar=None,
ax=ax,
hue_order=quality_score_order,
)
ax.set_xlabel("Sample")
ax.set_ylabel("Proportion of Reads (%)")
ax.set_title(
f"Percentage of Reads Above Quality Scores\n{multiplex_type.capitalize()} Samples"
)
legend = (
ax.legend(title="Quality Score", loc="lower right")
.get_title()
.set_fontweight("bold")
)
for tick in ax.get_xticklabels():
tick.set_rotation(45)
tick.set_ha("right")
locs, labels = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.1 for loc in locs])
if gs is None:
plt.tight_layout()
return fig
else:
ax1.get_legend().remove()
ax2.legend(
title="Quality Score", bbox_to_anchor=(1, 1.05), loc="upper left"
).get_title().set_fontweight("bold")
return None
except Exception as e:
logger.error(f"Error plotting QScore percentage distribution: {str(e)}")
raise
qscore_percentage_df = prepare_qscore_percentage_data(nanoplot_qc_metrics_df)
qscore_percentage_plot = plot_qscore_percentage(qscore_percentage_df)
In [11]:
with pl.Config(tbl_rows=len(qscore_percentage_df)):
display(qscore_percentage_df)
shape: (70, 4)
| anonymised_sample | multiplexing | Quality_Score | Percentage |
|---|---|---|---|
| str | str | str | f64 |
| "Sample 1" | "multiplex" | "Q5" | 0.999 |
| "Sample 2" | "multiplex" | "Q5" | 0.999 |
| "Sample 3" | "multiplex" | "Q5" | 0.999 |
| "Sample 4" | "multiplex" | "Q5" | 0.999 |
| "Sample 5" | "multiplex" | "Q5" | 0.999 |
| "Sample 6" | "multiplex" | "Q5" | 0.999 |
| "Sample 7" | "singleplex" | "Q5" | 0.999 |
| "Sample 8" | "singleplex" | "Q5" | 1.0 |
| "Sample 9" | "singleplex" | "Q5" | 0.999 |
| "Sample 10" | "singleplex" | "Q5" | 0.999 |
| "Sample 11" | "singleplex" | "Q5" | 1.0 |
| "Sample 12" | "singleplex" | "Q5" | 1.0 |
| "Sample 13" | "singleplex" | "Q5" | 1.0 |
| "Sample 14" | "singleplex" | "Q5" | 1.0 |
| "Sample 1" | "multiplex" | "Q7" | 0.998 |
| "Sample 2" | "multiplex" | "Q7" | 0.999 |
| "Sample 3" | "multiplex" | "Q7" | 0.998 |
| "Sample 4" | "multiplex" | "Q7" | 0.998 |
| "Sample 5" | "multiplex" | "Q7" | 0.998 |
| "Sample 6" | "multiplex" | "Q7" | 0.998 |
| "Sample 7" | "singleplex" | "Q7" | 0.999 |
| "Sample 8" | "singleplex" | "Q7" | 0.999 |
| "Sample 9" | "singleplex" | "Q7" | 0.998 |
| "Sample 10" | "singleplex" | "Q7" | 0.998 |
| "Sample 11" | "singleplex" | "Q7" | 0.999 |
| "Sample 12" | "singleplex" | "Q7" | 0.999 |
| "Sample 13" | "singleplex" | "Q7" | 0.999 |
| "Sample 14" | "singleplex" | "Q7" | 0.999 |
| "Sample 1" | "multiplex" | "Q10" | 0.993 |
| "Sample 2" | "multiplex" | "Q10" | 0.994 |
| "Sample 3" | "multiplex" | "Q10" | 0.993 |
| "Sample 4" | "multiplex" | "Q10" | 0.993 |
| "Sample 5" | "multiplex" | "Q10" | 0.991 |
| "Sample 6" | "multiplex" | "Q10" | 0.991 |
| "Sample 7" | "singleplex" | "Q10" | 0.994 |
| "Sample 8" | "singleplex" | "Q10" | 0.995 |
| "Sample 9" | "singleplex" | "Q10" | 0.993 |
| "Sample 10" | "singleplex" | "Q10" | 0.992 |
| "Sample 11" | "singleplex" | "Q10" | 0.995 |
| "Sample 12" | "singleplex" | "Q10" | 0.995 |
| "Sample 13" | "singleplex" | "Q10" | 0.995 |
| "Sample 14" | "singleplex" | "Q10" | 0.995 |
| "Sample 1" | "multiplex" | "Q12" | 0.977 |
| "Sample 2" | "multiplex" | "Q12" | 0.978 |
| "Sample 3" | "multiplex" | "Q12" | 0.978 |
| "Sample 4" | "multiplex" | "Q12" | 0.979 |
| "Sample 5" | "multiplex" | "Q12" | 0.969 |
| "Sample 6" | "multiplex" | "Q12" | 0.973 |
| "Sample 7" | "singleplex" | "Q12" | 0.976 |
| "Sample 8" | "singleplex" | "Q12" | 0.983 |
| "Sample 9" | "singleplex" | "Q12" | 0.978 |
| "Sample 10" | "singleplex" | "Q12" | 0.975 |
| "Sample 11" | "singleplex" | "Q12" | 0.98 |
| "Sample 12" | "singleplex" | "Q12" | 0.981 |
| "Sample 13" | "singleplex" | "Q12" | 0.981 |
| "Sample 14" | "singleplex" | "Q12" | 0.979 |
| "Sample 1" | "multiplex" | "Q15" | 0.892 |
| "Sample 2" | "multiplex" | "Q15" | 0.883 |
| "Sample 3" | "multiplex" | "Q15" | 0.892 |
| "Sample 4" | "multiplex" | "Q15" | 0.89 |
| "Sample 5" | "multiplex" | "Q15" | 0.857 |
| "Sample 6" | "multiplex" | "Q15" | 0.87 |
| "Sample 7" | "singleplex" | "Q15" | 0.877 |
| "Sample 8" | "singleplex" | "Q15" | 0.905 |
| "Sample 9" | "singleplex" | "Q15" | 0.892 |
| "Sample 10" | "singleplex" | "Q15" | 0.879 |
| "Sample 11" | "singleplex" | "Q15" | 0.882 |
| "Sample 12" | "singleplex" | "Q15" | 0.893 |
| "Sample 13" | "singleplex" | "Q15" | 0.884 |
| "Sample 14" | "singleplex" | "Q15" | 0.878 |
In [12]:
@dataclass
class QualityMetrics:
"""
Data class for storing base quality metrics statistics.
"""
max: float
min: float
mean: float
std: float
median: float
def _calculate_quality_stats(df: pl.DataFrame) -> QualityMetrics:
"""
Calculate base quality statistics from a Polars DataFrame.
Args:
df (pl.DataFrame): Input DataFrame containing quality metrics with column 'quals'
Returns:
QualityMetrics: Object containing base quality statistics
Raises:
Exception: If there's an error calculating statistics from the DataFrame
KeyError: If required column is missing from the DataFrame
"""
try:
return QualityMetrics(
max=df.select(pl.col("quals").max()).item(),
min=df.select(pl.col("quals").min()).item(),
mean=df.select(pl.col("quals").mean()).item(),
std=df.select(pl.col("quals").std()).item(),
median=df.select(pl.col("quals").median()).item(),
)
except Exception as e:
logger.error(f"Error calculating quality statistics: {str(e)}")
raise
def _print_quality_stats(stats: QualityMetrics, sample_type: str) -> None:
"""
Print formatted base quality statistics.
Args:
stats (QualityMetrics): QualityMetrics object containing statistics to print
sample_type (str): Type of sample (Multiplexed/Singleplexed)
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing quality statistics for {sample_type} samples")
print(f"\n{sample_type} Samples Statistics:")
print("=" * 40)
print("\nBase Qualities:")
for stat_name, value in vars(stats).items():
print(f" {stat_name.capitalize():6s}: {value:.2f}")
except Exception as e:
logger.error(f"Error printing quality statistics: {str(e)}")
raise
def analyze_basecall_quality(df: pl.DataFrame) -> None:
"""
Analyze and print base quality statistics for multiplexed and singleplexed samples.
Args:
df (pl.DataFrame): Input DataFrame containing quality metrics with columns:
- multiplexing: str ('singleplex' or 'multiplex')
- basecall: str ('sup' or other)
- quals: int/float
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
singleplex_quals = df.filter(
(pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
)
multiplex_quals = df.filter(
(pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
)
if singleplex_quals.height == 0 or multiplex_quals.height == 0:
logger.warning("No data found for either singleplex or multiplex samples")
return
singleplex_stats = _calculate_quality_stats(singleplex_quals)
multiplex_stats = _calculate_quality_stats(multiplex_quals)
_print_quality_stats(singleplex_stats, "Singleplexed")
_print_quality_stats(multiplex_stats, "Multiplexed")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat_name in ["mean", "median"]:
increase = _calculate_percentage_increase(
getattr(singleplex_stats, stat_name),
getattr(multiplex_stats, stat_name),
)
print(f"{stat_name.capitalize():6s} Base Quality: {increase:6.2f}%")
logger.info("Quality analysis completed successfully")
except Exception as e:
logger.error(f"Error in quality analysis: {str(e)}")
raise
analyze_basecall_quality(processed_aligned_nanoplot_df)
__main__ - INFO - Printing quality statistics for Singleplexed samples
__main__ - INFO - Printing quality statistics for Multiplexed samples
__main__ - INFO - Quality analysis completed successfully
Singleplexed Samples Statistics: ======================================== Base Qualities: Max : 49.64 Min : 1.88 Mean : 20.64 Std : 4.70 Median: 20.60 Multiplexed Samples Statistics: ======================================== Base Qualities: Max : 49.65 Min : 1.98 Mean : 20.45 Std : 4.69 Median: 20.34 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Base Quality: 0.95% Median Base Quality: 1.24%
2. Mapping Quality¶
In [13]:
def calculate_mapping_quality_distribution(
processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
"""
Calculate distribution of mapping quality scores across samples.
Args:
processed_data (pl.DataFrame): Processed NanoPlot data
basecall_type (str, optional): Basecall type to filter. Defaults to "sup"
Returns:
pl.DataFrame: Mapping quality distribution statistics
Raises:
ValueError: If required columns are missing
"""
try:
filtered_data = processed_data.filter(
pl.col("basecall") == basecall_type
).with_columns(pl.col("mapQ"), pl.col("number_of_reads"))
max_mapq = filtered_data.select(pl.col("mapQ").max()).item()
bins = np.arange(0, max_mapq + 0.5, 0.5)
mapping_dist = (
filtered_data.with_columns([pl.col("mapQ").cut(bins).alias("mapQ_bin")])
.group_by(["anonymised_sample", "mapQ_bin", "multiplexing"])
.agg(pl.len().alias("count"))
.join(
filtered_data.group_by("anonymised_sample").agg(
pl.first("number_of_reads").alias("number_of_reads")
),
on="anonymised_sample",
)
)
# Extract lower boundary of each bin
bin_categories = mapping_dist.select(pl.col("mapQ_bin").unique()).to_series()
bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]
mapping_df = pl.DataFrame(
{
"mapQ_bin": bin_categories,
"mapQ_bin_lower": bin_edges,
}
)
mapping_dist = mapping_dist.with_columns(
[(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
).join(mapping_df, on="mapQ_bin", how="left")
return mapping_dist
except Exception as e:
logger.error(f"Error calculating mapping quality distribution: {str(e)}")
raise
except Exception as e:
logger.error(f"Error calculating mapping quality distribution: {str(e)}")
raise
def plot_mapping_quality_distribution(
mapping_dist: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
line_alpha: float = 0.8,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot the distribution of mapping quality scores across samples.
Args:
mapping_dist (pl.DataFrame): DataFrame containing mapping quality distribution data
figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
line_alpha (float, optional): Line transparency. Defaults to 0.8.
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.
Returns:
Optional[Figure]: Figure object if created independently.
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {
"anonymised_sample",
"mapQ_bin_lower",
"percentage",
"multiplexing",
}
if not all(col in mapping_dist.columns for col in required_cols):
missing = required_cols - set(mapping_dist.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
plot_data = (
mapping_dist.with_columns(
[
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
]
)
.sort("sample_num")
.rename(
{
"multiplexing": r"$\mathbf{Multiplexing}$",
"anonymised_sample": r"$\mathbf{Sample}$",
}
)
)
sns.lineplot(
data=plot_data,
x="mapQ_bin_lower",
y="percentage",
hue=r"$\mathbf{Sample}$",
style=r"$\mathbf{Multiplexing}$",
alpha=line_alpha,
ax=ax,
)
ax.legend(loc="upper right", bbox_to_anchor=(1.02, 1.1))
ax.set_xlabel("Mapping Quality Score")
ax.set_ylabel("Proportion of Reads (%)")
ax.set_title("Distribution of Mapping Quality Scores")
max_mapq = int(plot_data["mapQ_bin_lower"].max())
tick_positions = np.arange(0, max_mapq + 1, 5)
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_positions)
if gs is None:
plt.tight_layout()
return fig
else:
ax.legend(bbox_to_anchor=(1.05, 1.05), loc="upper left")
return None
except Exception as e:
logger.error(f"Error plotting mapping quality distribution: {str(e)}")
raise
mapping_quality_distribution_df = calculate_mapping_quality_distribution(
processed_aligned_nanoplot_df
)
mapping_quality_dist_plot = plot_mapping_quality_distribution(
mapping_quality_distribution_df
)
In [14]:
@dataclass
class MappingMetrics:
"""
Data class for storing mapping quality metrics statistics.
"""
max: float
min: float
mean: float
std: float
median: float
def _calculate_mapping_stats(df: pl.DataFrame) -> MappingMetrics:
"""
Calculate mapping quality statistics from a Polars DataFrame.
Args:
df (pl.DataFrame): Input DataFrame containing mapping quality metrics with column 'mapQ'
Returns:
MappingMetrics: Object containing mapping quality statistics
Raises:
Exception: If there's an error calculating statistics from the DataFrame
KeyError: If required column is missing from the DataFrame
"""
try:
return MappingMetrics(
max=df.select(pl.col("mapQ").max()).item(),
min=df.select(pl.col("mapQ").min()).item(),
mean=df.select(pl.col("mapQ").mean()).item(),
std=df.select(pl.col("mapQ").std()).item(),
median=df.select(pl.col("mapQ").median()).item(),
)
except Exception as e:
logger.error(f"Error calculating mapping statistics: {str(e)}")
raise
def _print_mapping_stats(stats: MappingMetrics, sample_type: str) -> None:
"""
Print formatted mapping quality statistics.
Args:
stats (MappingMetrics): MappingMetrics object containing statistics to print
sample_type (str): Type of sample (Multiplexed/Singleplexed)
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing mapping statistics for {sample_type} samples")
print(f"\n{sample_type} Samples Statistics:")
print("=" * 40)
print("\nMapping Quality:")
for stat_name, value in vars(stats).items():
print(f" {stat_name.capitalize():6s}: {value:.2f}")
except Exception as e:
logger.error(f"Error printing mapping statistics: {str(e)}")
raise
def analyze_mapping_quality(df: pl.DataFrame) -> None:
"""
Analyze and print mapping quality statistics for multiplexed and singleplexed samples.
Args:
df (pl.DataFrame): Input DataFrame containing mapping metrics with columns:
- multiplexing: str ('singleplex' or 'multiplex')
- basecall: str ('sup' or other)
- mapQ: int/float
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
singleplex_quals = df.filter(
(pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
)
multiplex_quals = df.filter(
(pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
)
if singleplex_quals.height == 0 or multiplex_quals.height == 0:
logger.warning("No data found for either singleplex or multiplex samples")
return
singleplex_stats = _calculate_mapping_stats(singleplex_quals)
multiplex_stats = _calculate_mapping_stats(multiplex_quals)
_print_mapping_stats(singleplex_stats, "Singleplexed")
_print_mapping_stats(multiplex_stats, "Multiplexed")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat_name in ["mean", "median"]:
increase = _calculate_percentage_increase(
getattr(singleplex_stats, stat_name),
getattr(multiplex_stats, stat_name),
)
print(f"{stat_name.capitalize():6s} Mapping Quality: {increase:6.2f}%")
logger.info("Mapping analysis completed successfully")
except Exception as e:
logger.error(f"Error in mapping analysis: {str(e)}")
raise
analyze_mapping_quality(processed_aligned_nanoplot_df)
__main__ - INFO - Printing mapping statistics for Singleplexed samples
__main__ - INFO - Printing mapping statistics for Multiplexed samples
__main__ - INFO - Mapping analysis completed successfully
Singleplexed Samples Statistics: ======================================== Mapping Quality: Max : 60.00 Min : 0.00 Mean : 56.71 Std : 12.67 Median: 60.00 Multiplexed Samples Statistics: ======================================== Mapping Quality: Max : 60.00 Min : 0.00 Mean : 56.68 Std : 12.75 Median: 60.00 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Mapping Quality: 0.05% Median Mapping Quality: 0.00%
3. Combined Plots¶
In [15]:
def create_combined_quality_plot(
quality_distribution: pl.DataFrame,
mapping_quality_distribution: pl.DataFrame,
qscore_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
) -> plt.Figure:
try:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(2, 2)
# Plot base quality distribution (A)
plot_base_quality_distribution(
quality_distribution,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 0]),
)
# Plot mapping quality distribution (B)
plot_mapping_quality_distribution(
mapping_quality_distribution,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 1]),
)
# Plot QScore percentages (C and D)
plot_qscore_percentage(
qscore_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[1, :])
)
# Add panel labels
for i, label in enumerate(["A", "B", "C", "D"]):
ax = fig.axes[i]
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined quality plot: {str(e)}")
raise
combined_quality_plot = create_combined_quality_plot(
base_quality_distribution_df, mapping_quality_distribution_df, qscore_percentage_df
)
In [16]:
def process_mosdepth_file(file_path: Path, suffix: str) -> pl.DataFrame:
"""Process a mosdepth summary file and return relevant depth statistics.
Args:
file_path: Path to the mosdepth summary file
suffix: Suffix to remove from sample names
Returns:
DataFrame containing processed depth statistics with columns: chrom, mean, sample
"""
# sample_name will have the suffix removed.
sample_name = file_path.name.split(".")[0].replace(suffix, "")
df = pl.read_csv(file_path, separator="\t")
chromosomes = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
df = df.filter(
(pl.col("chrom").is_in(chromosomes + ["total"]))
& ~pl.col("chrom").str.ends_with("_region")
)
df = df.select(["chrom", "mean"])
df = df.with_columns(pl.lit(sample_name).alias("sample"))
return df
def process_per_base_file(file_path: Path, suffix: str) -> pl.DataFrame:
"""Process a mosdepth per-base file and calculate statistics per chromosome.
Args:
file_path: Path to the per-base depth file
suffix: Suffix to remove from sample names
Returns:
DataFrame with per-chromosome statistics including mean depth and standard error
"""
# sample_name will have the suffix removed.
sample_name = file_path.name.split(".")[0].replace(suffix, "")
df = pl.read_csv(
file_path,
separator="\t",
has_header=False,
new_columns=["chrom", "start", "end", "depth"],
)
chromosomes = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
return (
df.filter(pl.col("chrom").is_in(chromosomes))
.group_by("chrom")
.agg(
[
pl.col("depth").mean().alias("mean"),
(pl.col("depth").std() / pl.col("depth").count().sqrt()).alias("sem"),
]
)
.with_columns(pl.lit(sample_name).alias("sample"))
)
def analyze_mosdepth_data(
metrics_df: pl.DataFrame,
summary_files: List[Path],
per_base_files: List[Path],
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Analyze mosdepth data from summary and per-base files.
Args:
metrics_df: DataFrame containing sample metrics and metadata
summary_files: List of paths to mosdepth summary files
per_base_files: List of paths to mosdepth per-base files
Returns:
Tuple containing:
- DataFrame with per-chromosome depth statistics
- DataFrame with total depth statistics
Raises:
FileNotFoundError: If no mosdepth files are found
"""
try:
if not summary_files or not per_base_files:
raise FileNotFoundError("No mosdepth files found")
logger.info(f"Processing {len(summary_files)} mosdepth summary files")
all_dfs = [
process_mosdepth_file(file, basecall_suffix) for file in summary_files
]
depth_df = pl.concat(all_dfs)
logger.info(f"Processing {len(per_base_files)} per-base files")
all_per_base_dfs = [
process_per_base_file(file, basecall_suffix) for file in per_base_files
]
per_base_df = pl.concat(all_per_base_dfs)
# Join per-base stats (using both "chrom" and "sample" so that the right rows merge)
depth_df = depth_df.join(
per_base_df.rename({"mean": "per_base_mean", "sem": "per_base_sem"}),
on=["chrom", "sample"],
how="left",
)
total_depth_df = (
depth_df.filter(pl.col("chrom") == "total")
.unique(subset="sample")
.select(["sample", "mean"])
)
total_depth_df = (
total_depth_df.rename({"mean": "mean_depth"})
.join(
metrics_df.select(["sample", "multiplexing", "anonymised_sample"]),
on="sample",
)
.sort(["multiplexing", "sample"])
)
depth_df = depth_df.filter(pl.col("chrom") != "total")
chromosome_order = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
depth_df = depth_df.with_columns(pl.col("chrom").cast(pl.Categorical))
depth_df = depth_df.join(
metrics_df.select(["sample", "multiplexing", "anonymised_sample"]),
on="sample",
how="left",
)
# Create sample_num by extracting the digit from the anonymised sample name
depth_df = depth_df.with_columns(
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
).sort("sample_num")
wg_depth_df = (
depth_df.filter(pl.col("chrom").is_in(chromosome_order))
.unique(subset=["chrom", "anonymised_sample"])
.sort(["anonymised_sample", "multiplexing"])
)
logger.info(
f"Successfully processed depth data for {wg_depth_df.get_column('anonymised_sample').n_unique()} samples"
)
return wg_depth_df, total_depth_df
except Exception as e:
logger.error(f"Error analyzing mosdepth data: {str(e)}")
raise
def plot_mean_depth_per_chromosome(
wg_depth_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
line_alpha: float = 0.95,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot mean depth per chromosome with standard error of the mean.
Args:
wg_depth_df (pl.DataFrame): DataFrame containing whole-genome depth statistics
figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6)
dpi (int, optional): Figure DPI. Defaults to 300
line_alpha (float, optional): Line transparency. Defaults to 0.8
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure
Returns:
Optional[plt.Figure]: Figure object if created independently
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {
"anonymised_sample",
"chrom",
"mean",
"per_base_sem",
"multiplexing",
}
if not all(col in wg_depth_df.columns for col in required_cols):
missing = required_cols - set(wg_depth_df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
plot_df = (
wg_depth_df.with_columns(
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
)
.sort(["sample_num", "chrom"])
.rename(
{
"multiplexing": r"$\mathbf{Multiplexing}$",
"anonymised_sample": r"$\mathbf{Sample}$",
}
)
)
unique_samples = list(dict.fromkeys(plot_df[r"$\mathbf{Sample}$"].to_list()))
color_palette = sns.color_palette("husl", n_colors=len(unique_samples))
color_dict = dict(zip(unique_samples, color_palette))
sns.lineplot(
data=plot_df,
x="chrom",
y="mean",
hue=r"$\mathbf{Sample}$",
style=r"$\mathbf{Multiplexing}$",
legend="full",
palette=color_dict,
hue_order=unique_samples,
alpha=line_alpha,
ax=ax,
)
for sample in unique_samples:
sample_df = plot_df.filter(pl.col(r"$\mathbf{Sample}$") == sample)
ax.fill_between(
sample_df["chrom"],
sample_df["mean"] - sample_df["per_base_sem"],
sample_df["mean"] + sample_df["per_base_sem"],
alpha=0.25,
color=color_dict[sample], # Match fill color to line color
)
ax.set_title("Mean Depth per Chromosome (with SEM)")
ax.set_xlabel("Chromosome")
ax.set_ylabel("Mean Depth")
locs, labels = plt.xticks()
ax.set_xticks([loc + 0.01 for loc in locs])
ax.set_xticklabels(labels, rotation=45, ha="right")
ax.grid(axis="y", linestyle="--", alpha=0.7)
if gs is None:
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.tight_layout()
return fig
else:
ax.legend(bbox_to_anchor=(1.05, 1.05), loc="upper left")
return None
except Exception as e:
logger.error(f"Error plotting mean depth per chromosome: {str(e)}")
raise
basecall_suffix = "_sup"
mosdepth_summary_dir = Path("/scratch/prj/ppn_als_longread/ont-benchmark/qc/mosdepth/")
mosdepth_summary_files = list(
mosdepth_summary_dir.glob(
f"*{basecall_suffix}/*{basecall_suffix}.mosdepth.summary.txt"
)
)
mosdepth_per_base_files = list(
mosdepth_summary_dir.glob(f"*{basecall_suffix}/*{basecall_suffix}.per-base.bed.gz")
)
wg_depth_df, total_depth_df = analyze_mosdepth_data(
metrics_df=nanoplot_qc_metrics_df,
summary_files=mosdepth_summary_files,
per_base_files=mosdepth_per_base_files,
)
mean_depth_chr_plot = plot_mean_depth_per_chromosome(wg_depth_df)
__main__ - INFO - Processing 14 mosdepth summary files
__main__ - INFO - Processing 14 per-base files
__main__ - INFO - Successfully processed depth data for 14 samples
2. Mean Whole Genome Depth¶
In [17]:
def plot_mean_whole_genome_depth(
total_depth_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""Plot mean whole genome depth per sample.
Args:
total_depth_df: DataFrame containing total depth statistics
figsize: Figure size. Defaults to (14, 6)
dpi: Figure DPI. Defaults to 300
gs: GridSpec for plotting within a larger figure
Returns:
Optional[plt.Figure]: Figure object if created independently
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"anonymised_sample", "mean_depth", "multiplexing"}
if not all(col in total_depth_df.columns for col in required_cols):
missing = required_cols - set(total_depth_df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
plot_df = total_depth_df.with_columns(
pl.col("anonymised_sample")
.str.extract(r"(\d+)")
.cast(pl.Int32)
.alias("sample_num")
).sort("sample_num")
sns.barplot(
data=plot_df,
x="anonymised_sample",
y="mean_depth",
hue="multiplexing",
dodge=False,
order=plot_df["anonymised_sample"].to_list(),
ax=ax,
)
ax.set_title("Mean Whole Genome Depth per Sample")
ax.set_xlabel("Sample")
ax.set_ylabel("Depth")
locs, labels = plt.xticks()
ax.set_xticks([loc + 0.15 for loc in locs])
ax.set_xticklabels(labels, rotation=45, ha="right")
if gs is None:
legend = ax.legend(
bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
)
legend.get_title().set_weight("bold")
plt.tight_layout()
return fig
else:
legend = ax.legend(loc="upper left", title="Multiplexing")
legend.get_title().set_weight("bold")
return None
except Exception as e:
logger.error(f"Error plotting mean whole genome depth: {str(e)}")
raise
mean_depth_wg_plot = plot_mean_whole_genome_depth(total_depth_df)
In [18]:
@dataclass
class DepthMetrics:
"""
Data class for storing sequencing depth metrics statistics.
"""
max: float
min: float
mean: float
std: float
median: float
def _calculate_depth_stats(df: pl.DataFrame, column_name: str) -> DepthMetrics:
"""
Calculate depth statistics from a Polars DataFrame.
Args:
df (pl.DataFrame): Input DataFrame containing depth metrics
column_name (str): Name of the column containing depth values
Returns:
DepthMetrics: Object containing depth statistics
Raises:
Exception: If there's an error calculating statistics from the DataFrame
KeyError: If required column is missing from the DataFrame
"""
try:
return DepthMetrics(
max=df.select(pl.col(column_name).max()).item(),
min=df.select(pl.col(column_name).min()).item(),
mean=df.select(pl.col(column_name).mean()).item(),
std=df.select(pl.col(column_name).std()).item(),
median=df.select(pl.col(column_name).median()).item(),
)
except Exception as e:
logger.error(f"Error calculating depth statistics: {str(e)}")
raise
def _print_depth_stats(
stats: DepthMetrics, sample_type: str, depth_type: str = "per-chromosome"
) -> None:
"""
Print formatted depth statistics.
Args:
stats (DepthMetrics): DepthMetrics object containing statistics to print
sample_type (str): Type of sample (Multiplexed/Singleplexed)
depth_type (str): Type of depth calculation ("per-chromosome" or "whole-genome")
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing {depth_type} depth statistics for {sample_type} samples")
print(f"\n{sample_type} Samples Statistics:")
print("=" * 40)
print("\nDepth:")
for stat_name, value in vars(stats).items():
formatted_value = f"{value:.2f}"
print(f" {stat_name.capitalize():6s}: {formatted_value}")
except Exception as e:
logger.error(f"Error printing depth statistics: {str(e)}")
raise
def analyze_sequencing_depth(
wg_depth_df: pl.DataFrame, total_depth_df: pl.DataFrame
) -> None:
"""
Analyze and print sequencing depth statistics for multiplexed and singleplexed samples.
Args:
wg_depth_df (pl.DataFrame): DataFrame containing per-chromosome depth metrics
total_depth_df (pl.DataFrame): DataFrame containing whole genome depth metrics
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
# Per-chromosome depth analysis
singleplexed_depth = wg_depth_df.filter(pl.col("multiplexing") == "singleplex")
multiplexed_depth = wg_depth_df.filter(pl.col("multiplexing") == "multiplex")
if singleplexed_depth.height == 0 or multiplexed_depth.height == 0:
logger.warning("No data found for either singleplex or multiplex samples")
return
print("\nPer-Chromosome Depth Statistics:")
singleplex_depth_stats = _calculate_depth_stats(singleplexed_depth, "mean")
multiplex_depth_stats = _calculate_depth_stats(multiplexed_depth, "mean")
_print_depth_stats(singleplex_depth_stats, "Singleplexed", "per-chromosome")
_print_depth_stats(multiplex_depth_stats, "Multiplexed", "per-chromosome")
# Whole genome depth analysis
singleplexed_wg = total_depth_df.filter(pl.col("multiplexing") == "singleplex")
multiplexed_wg = total_depth_df.filter(pl.col("multiplexing") == "multiplex")
print("\nWhole Genome Depth Statistics:")
singleplex_wg_stats = _calculate_depth_stats(singleplexed_wg, "mean_depth")
multiplex_wg_stats = _calculate_depth_stats(multiplexed_wg, "mean_depth")
_print_depth_stats(singleplex_wg_stats, "Singleplexed", "whole-genome")
_print_depth_stats(multiplex_wg_stats, "Multiplexed", "whole-genome")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat_name in ["mean", "median"]:
wg_increase = _calculate_percentage_increase(
getattr(singleplex_wg_stats, stat_name),
getattr(multiplex_wg_stats, stat_name),
)
print(f"{stat_name.capitalize():6s} Depth: {wg_increase:6.2f}%")
except Exception as e:
logger.error(f"Error in depth analysis: {str(e)}")
raise
analyze_sequencing_depth(wg_depth_df, total_depth_df)
__main__ - INFO - Printing per-chromosome depth statistics for Singleplexed samples
__main__ - INFO - Printing per-chromosome depth statistics for Multiplexed samples
__main__ - INFO - Printing whole-genome depth statistics for Singleplexed samples
__main__ - INFO - Printing whole-genome depth statistics for Multiplexed samples
Per-Chromosome Depth Statistics: Singleplexed Samples Statistics: ======================================== Depth: Max : 31.15 Min : 1.23 Mean : 19.01 Std : 5.98 Median: 17.45 Multiplexed Samples Statistics: ======================================== Depth: Max : 13.54 Min : 0.44 Mean : 9.45 Std : 2.33 Median: 9.79 Whole Genome Depth Statistics: Singleplexed Samples Statistics: ======================================== Depth: Max : 27.80 Min : 14.42 Mean : 19.84 Std : 5.22 Median: 17.20 Multiplexed Samples Statistics: ======================================== Depth: Max : 12.11 Min : 7.86 Mean : 9.84 Std : 1.58 Median: 9.93 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Depth: 101.63% Median Depth: 73.30%
3. Flowcell Quality¶
In [19]:
def read_flowcell_stats(file_path: Path) -> pl.DataFrame:
"""
Read and process flowcell statistics from CSV file.
Args:
file_path: Path to the CSV file containing flowcell statistics
Returns:
pl.DataFrame: DataFrame containing processed flowcell statistics
Raises:
FileNotFoundError: If input file doesn't exist
ValueError: If required columns are missing
"""
try:
if not file_path.exists():
raise FileNotFoundError(f"Stats file not found: {file_path}")
df = pl.read_csv(file_path)
required_cols = {"flowcell_id", "number_pores_start"}
if not all(col in df.columns for col in required_cols):
missing = required_cols - set(df.columns)
raise ValueError(f"Missing required columns: {missing}")
df = df.with_columns(
[
pl.col("flowcell_id")
.cast(pl.Utf8)
.str.contains("__")
.map_elements(
lambda x: "multiplex" if x else "singleplex", return_dtype=pl.Utf8
)
.alias("multiplexing")
]
)
# Create anonymous flowcell names
multiplex_df = (
df.filter(pl.col("multiplexing") == "multiplex")
.with_row_index("index")
.with_columns(
pl.col("index")
.add(1)
.map_elements(lambda x: f"Multiplex Flowcell {x}", return_dtype=pl.Utf8)
.alias("new_flowcell_name")
)
)
singleplex_df = (
df.filter(pl.col("multiplexing") == "singleplex")
.with_row_index("index")
.with_columns(
pl.col("index")
.add(1)
.map_elements(
lambda x: f"Singleplex Flowcell {x}", return_dtype=pl.Utf8
)
.alias("new_flowcell_name")
)
)
return pl.concat([multiplex_df, singleplex_df]).sort(
["multiplexing", "flowcell_id"]
)
except Exception as e:
logger.error(f"Error reading flowcell stats: {str(e)}")
raise
def plot_flowcell_pores(
df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
marker_size: int = 8,
line_alpha: float = 0.8,
gs: Optional[gridspec.GridSpec] = None,
) -> plt.Figure:
"""
Plot the number of pores available at start across flowcells.
Args:
df (pl.DataFrame): DataFrame containing flowcell statistics
figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
dpi (int, optional): Figure DPI. Defaults to 300.
marker_size (int, optional): Size of markers. Defaults to 8.
line_alpha (float, optional): Line transparency. Defaults to 0.8.
Returns:
plt.Figure: Matplotlib figure object
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"new_flowcell_name", "number_pores_start", "multiplexing"}
if not all(col in df.columns for col in required_cols):
missing = required_cols - set(df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
sns.lineplot(
data=df,
x="new_flowcell_name",
y="number_pores_start",
hue="multiplexing",
marker="o",
style="multiplexing",
alpha=line_alpha,
markersize=marker_size,
ax=ax,
)
ax.set_title("Number of Pores Available at Start of Sequencing")
ax.set_xlabel("Flowcell ID")
ax.set_ylabel("Number of Pores")
plt.xticks(rotation=45, ha="right")
ax.set_ylim(bottom=0)
if gs is None:
legend = ax.legend(
bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
)
legend.get_title().set_weight("bold")
plt.tight_layout()
return fig
else:
legend = ax.legend(loc="lower right", title="Multiplexing")
legend.get_title().set_weight("bold")
return None
except Exception as e:
logger.error(f"Error plotting flowcell pores: {str(e)}")
raise
seq_stats_path = Path("/scratch/prj/ppn_als_longread/ont-benchmark/seq_stats.csv")
flowcell_stats_df = read_flowcell_stats(seq_stats_path)
flowcell_plot = plot_flowcell_pores(flowcell_stats_df)
In [20]:
@dataclass
class PoreMetrics:
"""
Data class for storing flowcell pore statistics.
"""
max: float
min: float
mean: float
std: float
median: float
def _calculate_pore_stats(df: pl.DataFrame, multiplexing_type: str) -> PoreMetrics:
"""
Calculate pore statistics from a Polars DataFrame.
Args:
df: Input DataFrame containing pore metrics
multiplexing_type: Type of multiplexing to filter by
Returns:
PoreMetrics: Object containing pore statistics
Raises:
ValueError: If required columns are missing
Exception: If there's an error calculating statistics
"""
try:
required_cols = {"multiplexing", "number_pores_start"}
if not all(col in df.columns for col in required_cols):
missing = required_cols - set(df.columns)
raise ValueError(f"Missing required columns: {missing}")
subset = df.filter(pl.col("multiplexing") == multiplexing_type)
return PoreMetrics(
max=subset.select(pl.col("number_pores_start").max()).item(),
min=subset.select(pl.col("number_pores_start").min()).item(),
mean=subset.select(pl.col("number_pores_start").mean()).item(),
std=subset.select(pl.col("number_pores_start").std()).item(),
median=subset.select(pl.col("number_pores_start").median()).item(),
)
except Exception as e:
logger.error(f"Error calculating pore statistics: {str(e)}")
raise
def _print_pore_stats(stats: PoreMetrics, sample_type: str) -> None:
"""
Print formatted pore statistics.
Args:
stats: PoreMetrics object containing statistics to print
sample_type: Type of sample (Multiplexed/Singleplexed)
Raises:
Exception: If there's an error formatting or printing the statistics
"""
try:
logger.info(f"Printing pore statistics for {sample_type} flowcells")
print(f"\n{sample_type} Flowcells Statistics:")
print("=" * 40)
print("\nNumber of Pores at Start:")
for stat_name, value in vars(stats).items():
formatted_value = f"{value:.2f}"
print(f" {stat_name.capitalize():6s}: {formatted_value}")
except Exception as e:
logger.error(f"Error printing pore statistics: {str(e)}")
raise
def analyze_flowcell_pores(df: pl.DataFrame) -> None:
"""
Analyze and print flowcell pore statistics for multiplexed and singleplexed samples.
Args:
df: DataFrame containing flowcell pore metrics
Raises:
Exception: If there's an error during analysis
ValueError: If required data is missing from the DataFrame
"""
try:
singleplex_stats = _calculate_pore_stats(df, "singleplex")
multiplex_stats = _calculate_pore_stats(df, "multiplex")
_print_pore_stats(singleplex_stats, "Singleplexed")
_print_pore_stats(multiplex_stats, "Multiplexed")
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat_name in ["mean", "median"]:
increase = _calculate_percentage_increase(
getattr(singleplex_stats, stat_name),
getattr(multiplex_stats, stat_name),
)
print(f"{stat_name.capitalize():6s} Number of Pores: {increase:6.2f}%")
except Exception as e:
logger.error(f"Error in pore analysis: {str(e)}")
raise
analyze_flowcell_pores(flowcell_stats_df)
__main__ - INFO - Printing pore statistics for Singleplexed flowcells
__main__ - INFO - Printing pore statistics for Multiplexed flowcells
Singleplexed Flowcells Statistics: ======================================== Number of Pores at Start: Max : 8152.00 Min : 4874.00 Mean : 7008.75 Std : 1190.38 Median: 7422.50 Multiplexed Flowcells Statistics: ======================================== Number of Pores at Start: Max : 8223.00 Min : 8024.00 Mean : 8116.00 Std : 100.34 Median: 8101.00 Percentage Increase (Singleplexed vs Multiplexed): ======================================== Mean Number of Pores: -13.64% Median Number of Pores: -8.38%
4. Relation between Flowcell Quality and Mean Whole Genome Depth¶
In [21]:
def parse_seq_stats_data(
seq_stats_df: pl.DataFrame, depth_df: pl.DataFrame
) -> pl.DataFrame:
"""
Parse sequencing statistics data and merge with depth information.
Args:
seq_stats_df: Polars DataFrame containing sequencing statistics
depth_df: Polars DataFrame containing depth information
Returns:
pl.DataFrame: Merged DataFrame with correctly summed depths for multiplexed samples
Raises:
ValueError: If required columns are missing
"""
required_seq_cols = {"flowcell_id", "multiplexing"}
required_depth_cols = {"sample", "mean_depth"}
if not all(col in seq_stats_df.columns for col in required_seq_cols):
missing = required_seq_cols - set(seq_stats_df.columns)
raise ValueError(f"Missing required columns in seq_stats_df: {missing}")
if not all(col in depth_df.columns for col in required_depth_cols):
missing = required_depth_cols - set(depth_df.columns)
raise ValueError(f"Missing required columns in depth_df: {missing}")
# Create a mapping DataFrame for multiplexed samples
multiplexed_samples = (
seq_stats_df.filter(pl.col("multiplexing") == "multiplex")
.select("flowcell_id")
.with_columns([pl.col("flowcell_id").str.split("__").alias("sample_ids")])
.explode("sample_ids")
)
# Create mapping for singleplex samples
singleplex_samples = (
seq_stats_df.filter(pl.col("multiplexing") == "singleplex")
.select("flowcell_id")
.with_columns([pl.col("flowcell_id").alias("sample_ids")])
)
# Combine mappings
all_samples = pl.concat([multiplexed_samples, singleplex_samples])
# Join with depth information and aggregate
merged_depths = (
all_samples.join(depth_df, left_on="sample_ids", right_on="sample", how="left")
.group_by("flowcell_id")
.agg([pl.col("mean_depth").sum().alias("total_mean_depth")])
)
# Join back to original stats DataFrame
return seq_stats_df.join(merged_depths, on="flowcell_id", how="left")
def plot_flowcell_depth_correlation(
df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
marker_size: int = 100,
confidence_alpha: float = 0.2,
gs: Optional[gridspec.GridSpec] = None,
) -> Tuple[plt.Figure, Tuple[float, float, float, float]]:
"""
Plot correlation between number of pores and sequencing depth.
Args:
df: DataFrame containing flowcell statistics and depth data
figsize: Figure size (width, height)
dpi: Figure resolution
marker_size: Size of scatter plot markers
confidence_alpha: Transparency of confidence interval
gs: Optional GridSpec for subplot placement
Returns:
Tuple containing:
- plt.Figure: Matplotlib figure object
- Tuple[float, float, float, float]: (slope, intercept, r_value, p_value)
Raises:
ValueError: If required columns are missing
"""
required_cols = {"number_pores_start", "total_mean_depth", "multiplexing"}
if not all(col in df.columns for col in required_cols):
missing = required_cols - set(df.columns)
raise ValueError(f"Missing required columns: {missing}")
try:
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
# Create scatter plot
sns.scatterplot(
data=df,
x="number_pores_start",
y="total_mean_depth",
hue="multiplexing",
s=marker_size,
ax=ax,
)
# Calculate regression
x = df["number_pores_start"]
y = df["total_mean_depth"]
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
# Plot regression line and confidence interval
x_line = np.linspace(x.min(), x.max(), 100)
y_line = slope * x_line + intercept
ax.plot(
x_line,
y_line,
color="gray",
linestyle="-",
linewidth=2,
label="line of best fit",
)
# Calculate confidence interval
n = len(x)
y_pred = slope * x + intercept
s_err = np.sqrt((y - y_pred).pow(2).sum() / (n - 2))
t = stats.t.ppf(0.975, n - 2)
ci = (
t
* s_err
* np.sqrt(1 / n + (x_line - x.mean()) ** 2 / ((x - x.mean()) ** 2).sum())
)
ax.fill_between(
x_line,
y_line - ci,
y_line + ci,
color="gray",
alpha=confidence_alpha,
label="95% Confidence Interval",
)
ax.set_title(
f"Number of Pores at Start vs Total Mean Whole Genome Depth\n"
f"r = {r_value:.2f}, p = {p_value:.2g}"
)
ax.set_xlabel("Number of Pores at Start")
ax.set_ylabel("Mean Whole Genome Depth")
if gs is None:
legend = ax.legend(
bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
)
legend.get_title().set_weight("bold")
plt.tight_layout()
return fig, (slope, intercept, r_value, p_value)
else:
legend = ax.legend(loc="lower right", title="Multiplexing")
legend.get_title().set_weight("bold")
return None
except Exception as e:
logger.error(f"Error plotting flowcell depth correlation: {str(e)}")
raise
merged_stats_df = parse_seq_stats_data(flowcell_stats_df, total_depth_df)
fig, pores_depth_regression_results = plot_flowcell_depth_correlation(merged_stats_df)
slope, intercept, r_value, p_value = pores_depth_regression_results
logger.info(
f"Regression statistics:\n"
f"Slope: {slope:.4f}\n"
f"Intercept: {intercept:.4f}\n"
f"R-value: {r_value:.4f}\n"
f"P-value: {p_value:.4e}"
)
__main__ - INFO - Regression statistics: Slope: 0.0025 Intercept: 1.6324 R-value: 0.6387 P-value: 3.4417e-02
5. Barcoding Quality¶
In [22]:
def _get_sample_barcode_mapping() -> Dict[str, str]:
"""
Get mapping between sample IDs and barcodes.
Returns:
Dict[str, str]: Mapping of sample IDs to barcodes
"""
return {
"A046_12": "barcode01",
"A079_07": "barcode02",
"A081_91": "barcode03",
"A048_09": "barcode04",
"A097_92": "barcode05",
"A085_00": "barcode06",
}
def _parse_nanostats_barcoded(file_path: Path) -> Dict[str, int]:
"""
Parse NanoStats barcoded file.
Args:
file_path (Path): Path to NanoStats barcoded file
Returns:
Dict[str, int]: Mapping of barcodes to read counts
"""
metrics = {}
with open(file_path, "r") as f:
header = f.readline().strip().split("\t")
values = f.readline().strip().split("\t")
for barcode, value in zip(header[1:], values[1:]):
if barcode == "unclassified" or barcode.startswith("barcode"):
metrics[barcode] = int(value)
return metrics
def _parse_flowcell_samples(seq_summaries_dir: Path) -> Dict[str, List[str]]:
"""
Parse flowcell samples from directory names.
Args:
seq_summaries_dir (Path): Directory containing sequencing summaries
Returns:
Dict[str, List[str]]: Mapping of flowcell names to sample lists
"""
flowcell_samples = {}
for subdir in Path(seq_summaries_dir).iterdir():
if "__" in subdir.name:
samples = subdir.name.split("__")
flowcell_samples[subdir.name] = samples
return flowcell_samples
def plot_multiplexed_flowcell_reads(
seq_summaries_dir: Path,
figsize: Tuple[int, int] = (12, 6),
dpi: int = 300,
bar_width: float = 0.25,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot multiplexed flowcell reads distribution using Polars.
Args:
seq_summaries_dir (Path): Directory containing sequencing summaries
figsize (Tuple[int, int]): Figure size
dpi (int): Figure DPI
bar_width (float): Width of bars in plot
color_palette (str): Seaborn color palette name
gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure
Returns:
Optional[plt.Figure]: Figure object if created independently, None if using gridspec
"""
try:
sample_barcode_mapping = _get_sample_barcode_mapping()
barcode_sample_mapping = {v: k for k, v in sample_barcode_mapping.items()}
flowcell_samples = _parse_flowcell_samples(seq_summaries_dir)
flowcell_rename = {
name: f"Multiplex Flowcell {i+1}"
for i, name in enumerate(flowcell_samples.keys())
}
unique_samples = sorted(set(sum(flowcell_samples.values(), [])))
sample_rename = {
sample: f"Sample {i+1 if i != 1 and i != 2 else 3 if i == 1 else 2}"
for i, sample in enumerate(unique_samples)
}
sample_rename["Unclassified"] = "Unclassified"
data = []
for subdir in Path(seq_summaries_dir).iterdir():
if "__" not in subdir.name:
continue
flowcell_name = subdir.name
nanostats_path = subdir / "NanoStats_barcoded.txt"
if not nanostats_path.exists():
continue
metrics = _parse_nanostats_barcoded(nanostats_path)
for barcode, read_count in metrics.items():
if barcode in barcode_sample_mapping:
sample = barcode_sample_mapping[barcode]
if sample in flowcell_samples[flowcell_name]:
data.append(
{
"Flowcell": flowcell_rename[flowcell_name],
"Sample": sample_rename[sample],
"Read Count": read_count,
}
)
elif barcode == "unclassified":
data.append(
{
"Flowcell": flowcell_rename[flowcell_name],
"Sample": "Unclassified",
"Read Count": read_count,
}
)
if not data:
raise ValueError("No valid data found for plotting")
df = pl.DataFrame(data)
df = df.with_columns(
[
pl.col("Flowcell").cast(pl.Categorical),
pl.col("Sample").cast(pl.Categorical),
]
)
df = df.sort(["Flowcell", "Sample"])
flowcell_stats = (
df.group_by("Flowcell")
.agg(
[
pl.col("Read Count").mean().alias("mean_reads"),
pl.col("Read Count").std().alias("std_reads"),
]
)
.with_columns(
(pl.col("std_reads") / pl.col("mean_reads") * 100).alias(
"cv_percentage"
)
)
)
for row in flowcell_stats.iter_rows(named=True):
print(
f"{row['Flowcell']} - Coefficient of Variation: {row['cv_percentage']:.2f}%"
)
unclassified_stats = (
df.group_by("Flowcell")
.agg(
pl.col("Read Count").sum().alias("total_reads"),
pl.col("Read Count")
.filter(pl.col("Sample") == "Unclassified")
.alias("unclassified_reads"),
)
.with_columns(
(pl.col("unclassified_reads") / pl.col("total_reads") * 100).alias(
"unclassified_percentage"
)
)
)
mean_unclassified = (
unclassified_stats["unclassified_percentage"].explode().mean()
)
std_unclassified = unclassified_stats["unclassified_percentage"].explode().std()
print(f"Unclassified reads: {mean_unclassified:.2f}% ± {std_unclassified:.2f}%")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
color = sns.color_palette()[0]
flowcells = df.select(pl.col("Flowcell").unique()).to_series()
x = range(len(flowcells) * 3)
bars = ax.bar(x, df.select("Read Count").to_series(), bar_width, color=color)
y_max = ax.get_ylim()[1] # Upper limit of y-axis
if y_max > 0:
scale = int(np.floor(np.log10(y_max))) # Compute order of magnitude
if (
scale >= 3
): # Only apply if the scale is meaningful (e.g., thousands or more)
ax.set_ylabel(f"Number of reads ($1×10^{{{scale}}}$)")
else:
ax.set_ylabel("Number of reads")
else:
ax.set_ylabel("Number of reads")
ax.set_title("Number of barcoded reads")
ax.tick_params(
axis="x", which="both", bottom=False, top=False, labelbottom=False
)
ax.xaxis.grid(False)
for i, bar in enumerate(bars):
sample = df.row(i)[1] # Get Sample value
ax.text(
bar.get_x() + bar.get_width() / 2 - 0.3,
-1.8e6,
sample,
ha="center",
va="bottom",
rotation=45,
)
for i, flowcell in enumerate(flowcells):
ax.text(i * 3 + 1, -2.4e6, flowcell, ha="center")
for i in range(1, len(flowcells)):
ax.axvline(x=i * 3 - 0.5, color="gray", linestyle="-", linewidth=0.5)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting multiplexed flowcell reads: {str(e)}")
raise
multiplex_flowcell_reads_plot = plot_multiplexed_flowcell_reads(np_seq_summaries_dir)
Multiplex Flowcell 1 - Coefficient of Variation: 63.78% Multiplex Flowcell 2 - Coefficient of Variation: 31.77% Multiplex Flowcell 3 - Coefficient of Variation: 38.91% Unclassified reads: 19.04% ± 3.15%
6. Combined Plots¶
In [23]:
def create_combined_sequencing_plot(
wg_depth_df: pl.DataFrame,
total_depth_df: pl.DataFrame,
flowcell_stats_df: pl.DataFrame,
merged_stats_df: pl.DataFrame,
seq_summaries_dir: Path,
figsize: Tuple[int, int] = (12, 16),
dpi: int = 300,
) -> plt.Figure:
try:
# Create figure with GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(4, 2, height_ratios=[1, 1, 1, 1])
# Plot A: Mean Depth per Chromosome (full width)
plot_mean_depth_per_chromosome(
wg_depth_df,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, :]),
)
# Plot B: Mean Whole Genome Depth per Sample
plot_mean_whole_genome_depth(
total_depth_df,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[1, 0]),
)
# Plot C: Number of Pores Available
plot_flowcell_pores(
flowcell_stats_df,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[1, 1]),
)
# Plot D: Correlation Plot (capture both figure and statistics)
plot_flowcell_depth_correlation(
merged_stats_df,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2, 0]),
)
# Plot E: Multiplexed Flowcell Reads
plot_multiplexed_flowcell_reads(
seq_summaries_dir,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2, 1]),
)
# Add panel labels
for i, label in enumerate(["A", "B", "C", "D", "E"]):
ax = fig.axes[i]
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined sequencing plots: {str(e)}")
raise
combined_sequencing_plot = create_combined_sequencing_plot(
wg_depth_df=wg_depth_df,
total_depth_df=total_depth_df,
flowcell_stats_df=flowcell_stats_df,
merged_stats_df=merged_stats_df,
seq_summaries_dir=np_seq_summaries_dir,
)
Multiplex Flowcell 1 - Coefficient of Variation: 63.78% Multiplex Flowcell 2 - Coefficient of Variation: 31.77% Multiplex Flowcell 3 - Coefficient of Variation: 38.91% Unclassified reads: 19.04% ± 3.15%
In [24]:
@dataclass
class SNVRTGMetrics:
"""Data class for RTG vcfeval metrics."""
true_pos_baseline: int
true_pos_call: int
false_pos: int
false_neg: int
precision: float
sensitivity: float
f_measure: float
@dataclass
class SNVAnalysisConfig:
"""Configuration for RTG analysis."""
technologies: tuple[str, ...] = ("ont", "illumina")
complexities: tuple[str, ...] = ("hc", "lc")
metrics_to_test: tuple[str, ...] = ("precision", "sensitivity", "f_measure")
@dataclass
class StatisticalResults:
"""Container for statistical test results."""
t_statistic: float
p_value: float
adjusted_p_value: Optional[float] = None
def _read_rtg_summary(file_path: Path) -> Optional[Dict[str, float]]:
"""
Read and parse RTG vcfeval summary file.
Args:
file_path: Path to the summary file
Returns:
Dictionary containing RTG metrics or None if parsing fails
"""
try:
with open(file_path, "r") as f:
lines = f.readlines()
none_line = next(
(line for line in lines if line.strip().startswith("None")), None
)
if none_line:
values = none_line.split()
return SNVRTGMetrics(
true_pos_baseline=int(values[1]),
true_pos_call=int(values[2]),
false_pos=int(values[3]),
false_neg=int(values[4]),
precision=float(values[5]),
sensitivity=float(values[6]),
f_measure=float(values[7]),
).__dict__
logger.warning(f"No 'None' threshold line found in {file_path}")
return None
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
return None
except Exception as e:
logger.error(f"Error reading file {file_path}: {str(e)}")
return None
def _calculate_rtg_statistics(df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate statistics for RTG metrics grouped by complexity.
Args:
df: Input DataFrame containing RTG metrics
Returns:
DataFrame with calculated statistics
"""
metrics = ["precision", "sensitivity", "f_measure"]
stats_exprs = []
for metric in metrics:
stats_exprs.extend(
[
pl.col(metric).mean().alias(f"{metric}_mean"),
pl.col(metric).std().alias(f"{metric}_std"),
pl.col(metric).median().alias(f"{metric}_median"),
pl.col(metric).min().alias(f"{metric}_min"),
pl.col(metric).max().alias(f"{metric}_max"),
]
)
return df.group_by("complexity").agg(stats_exprs)
def collect_snv_rtg_metrics(
sample_ids: pl.DataFrame, base_path: Path, config: SNVAnalysisConfig
) -> Dict[str, pl.DataFrame]:
"""
Collect SNV metrics from RTG vcfeval summary files.
Args:
sample_ids: DataFrame containing sample IDs
base_path: Base path to the summary files
config: Analysis configuration
Returns:
Dictionary containing DataFrames with metrics for each technology
"""
metrics_data = {tech: [] for tech in config.technologies}
for row in sample_ids.iter_rows(named=True):
for tech in config.technologies:
for complexity in config.complexities:
sample_id = row["ont_id"] if tech == "ont" else row["lp_id"]
summary_file = (
base_path
/ complexity
/ "aggregate"
/ f"{sample_id}.snv"
/ "summary.txt"
)
if summary := _read_rtg_summary(summary_file):
metrics_entry = {
"sample_id": sample_id,
"complexity": complexity,
**summary,
}
metrics_data[tech].append(metrics_entry)
else:
logger.warning(
f"Skipping empty summary for {sample_id}, {tech}, {complexity}"
)
return {
tech: pl.DataFrame(data)
for tech, data in metrics_data.items()
if data # Only include non-empty data
}
def _perform_ttest(
ont_data: pl.DataFrame, illumina_data: pl.DataFrame, metric: str
) -> Tuple[float, float]:
"""
Perform t-test between ONT and Illumina data for a given metric.
Args:
ont_data: DataFrame containing ONT metrics
illumina_data: DataFrame containing Illumina metrics
metric: Name of the metric to test
Returns:
Tuple containing t-statistic and p-value
"""
ont_values = ont_data.get_column(metric).to_numpy()
illumina_values = illumina_data.get_column(metric).to_numpy()
return stats.ttest_ind(ont_values, illumina_values)
def run_rtg_statistical_analysis(
rtg_metrics_dfs: Dict[str, pl.DataFrame], config: SNVAnalysisConfig
) -> Dict[str, pl.DataFrame]:
"""
Run statistical analysis on RTG metrics.
Args:
rtg_metrics_dfs: Dictionary of DataFrames containing metrics for each technology
config: Analysis configuration
Returns:
Dictionary containing statistical results for each technology
"""
results: Dict[str, List[Dict]] = {
complexity: [] for complexity in config.complexities
}
all_p_values = []
for complexity in config.complexities:
ont_data = rtg_metrics_dfs["ont"].filter(pl.col("complexity") == complexity)
illumina_data = rtg_metrics_dfs["illumina"].filter(
pl.col("complexity") == complexity
)
for metric in config.metrics_to_test:
metric_lower = metric.lower()
t_stat, p_value = _perform_ttest(ont_data, illumina_data, metric_lower)
results[complexity].append(
{"metric": metric, "t_statistic": t_stat, "p_value": p_value}
)
all_p_values.append(p_value)
# FDR correction
_, adjusted_p_values = multipletests(all_p_values, method="fdr_bh")[:2]
p_value_idx = 0
for complexity in results:
for result in results[complexity]:
result["adjusted_p_value"] = adjusted_p_values[p_value_idx]
p_value_idx += 1
# Convert results to DataFrames
return {complexity: pl.DataFrame(data) for complexity, data in results.items()}
def display_combined_snv_rtg_statistics(
rtg_metrics_dfs: Dict[str, pl.DataFrame],
statistical_results: Dict[str, pl.DataFrame],
config: SNVAnalysisConfig,
) -> Dict[str, Dict[str, pl.DataFrame]]:
"""
Display combined statistics and statistical test results.
Args:
rtg_metrics_dfs: Dictionary of DataFrames containing metrics
statistical_results: Dictionary of DataFrames containing statistical results
config: Analysis configuration
Returns:
Dictionary containing ONT and Illumina statistics across complexities.
"""
stats_data = {}
print("\n### Sequencing Platform Statistics ###")
for tech in config.technologies:
print(f"\n--- {tech.upper()} ---")
# Calculate and store statistics for both complexities together
combined_stats = _calculate_rtg_statistics(rtg_metrics_dfs[tech])
stats_data[tech] = combined_stats
display(combined_stats)
# Count total true positives, false negatives, and false positives per complexity
variant_counts_df = (
rtg_metrics_dfs[tech]
.group_by("complexity")
.agg(
pl.col("true_pos_baseline").sum().alias("total_true_positives"),
pl.col("false_neg").sum().alias("total_false_negatives"),
pl.col("false_pos").sum().alias("total_false_positives"),
)
.with_columns(
(
pl.col("total_true_positives")
+ pl.col("total_false_negatives")
+ pl.col("total_false_positives")
).alias("total_variants")
)
)
print("Variant counts:")
display(variant_counts_df)
print("\n### Statistical Tests by Complexity ###")
for complexity in config.complexities:
print(f"\nResults for {complexity.upper()} regions:")
display(statistical_results[complexity])
return stats_data
sample_ids = pl.read_csv("sample_ids.csv")
snv_config = SNVAnalysisConfig()
base_path = Path("/scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval")
snv_rtg_metrics_dfs = collect_snv_rtg_metrics(sample_ids, base_path, snv_config)
snv_rtg_statistical_results = run_rtg_statistical_analysis(
snv_rtg_metrics_dfs, snv_config
)
snv_rtg_statistics = display_combined_snv_rtg_statistics(
snv_rtg_metrics_dfs, snv_rtg_statistical_results, snv_config
)
# Create combined DataFrames containing data for all complexities
snv_ont_stats = snv_rtg_statistics["ont"].with_columns(
pl.when(pl.col("complexity").is_null())
.then(pl.lit("unknown"))
.otherwise(pl.col("complexity"))
.alias("complexity")
)
snv_illumina_stats = snv_rtg_statistics["illumina"].with_columns(
pl.when(pl.col("complexity").is_null())
.then(pl.lit("unknown"))
.otherwise(pl.col("complexity"))
.alias("complexity")
)
### Sequencing Platform Statistics ### --- ONT ---
shape: (2, 16)
| complexity | precision_mean | precision_std | precision_median | precision_min | precision_max | sensitivity_mean | sensitivity_std | sensitivity_median | sensitivity_min | sensitivity_max | f_measure_mean | f_measure_std | f_measure_median | f_measure_min | f_measure_max |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| str | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "lc" | 0.777957 | 0.013331 | 0.78075 | 0.7486 | 0.7936 | 0.734221 | 0.018678 | 0.74115 | 0.6945 | 0.7547 | 0.755429 | 0.015967 | 0.76115 | 0.7205 | 0.7723 |
| "hc" | 0.953079 | 0.009838 | 0.9551 | 0.9286 | 0.9625 | 0.955564 | 0.021229 | 0.9649 | 0.9085 | 0.9749 | 0.954293 | 0.015401 | 0.9605 | 0.9184 | 0.9686 |
Variant counts:
shape: (2, 5)
| complexity | total_true_positives | total_false_negatives | total_false_positives | total_variants |
|---|---|---|---|---|
| str | i64 | i64 | i64 | i64 |
| "lc" | 298502 | 108032 | 85131 | 491665 |
| "hc" | 8720542 | 405066 | 428001 | 9553609 |
--- ILLUMINA ---
shape: (2, 16)
| complexity | precision_mean | precision_std | precision_median | precision_min | precision_max | sensitivity_mean | sensitivity_std | sensitivity_median | sensitivity_min | sensitivity_max | f_measure_mean | f_measure_std | f_measure_median | f_measure_min | f_measure_max |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| str | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "hc" | 0.962943 | 0.003552 | 0.96495 | 0.9545 | 0.9663 | 0.972471 | 0.000851 | 0.97285 | 0.9706 | 0.9736 | 0.967693 | 0.002114 | 0.9689 | 0.9625 | 0.9696 |
| "lc" | 0.79945 | 0.004808 | 0.8009 | 0.7882 | 0.8046 | 0.7433 | 0.002283 | 0.7432 | 0.7388 | 0.7463 | 0.770336 | 0.003014 | 0.77065 | 0.7627 | 0.7738 |
Variant counts:
shape: (2, 5)
| complexity | total_true_positives | total_false_negatives | total_false_positives | total_variants |
|---|---|---|---|---|
| str | i64 | i64 | i64 | i64 |
| "lc" | 302174 | 104360 | 75798 | 482332 |
| "hc" | 8874480 | 251128 | 341352 | 9466960 |
### Statistical Tests by Complexity ### Results for HC regions:
shape: (3, 4)
| metric | t_statistic | p_value | adjusted_p_value |
|---|---|---|---|
| str | f64 | f64 | f64 |
| "precision" | -3.528853 | 0.001576 | 0.004023 |
| "sensitivity" | -2.977588 | 0.006214 | 0.007457 |
| "f_measure" | -3.225201 | 0.003384 | 0.005076 |
Results for LC regions:
shape: (3, 4)
| metric | t_statistic | p_value | adjusted_p_value |
|---|---|---|---|
| str | f64 | f64 | f64 |
| "precision" | -5.674713 | 0.000006 | 0.000034 |
| "sensitivity" | -1.805218 | 0.082635 | 0.082635 |
| "f_measure" | -3.432703 | 0.002012 | 0.004023 |
In [25]:
def _get_significance_level(p_value: float) -> str:
"""
Determine significance level based on p-value.
Args:
p_value: Statistical p-value
Returns:
str: Significance level indicator
"""
if p_value < 0.001:
return "***"
elif p_value < 0.01:
return "**"
elif p_value < 0.05:
return "*"
return ""
def prepare_snv_performance_data(
ont_stats: pl.DataFrame,
illumina_stats: pl.DataFrame,
stat_results: Dict[str, pl.DataFrame],
metrics: Tuple[str, ...],
complexities: Tuple[str, ...],
) -> pl.DataFrame:
"""
Prepare data for performance visualization.
Args:
ont_stats: ONT statistics DataFrame
illumina_stats: Illumina statistics DataFrame
stat_results: Dictionary containing statistical test results
metrics: Metrics to plot
complexities: Complexity levels to plot
Returns:
pl.DataFrame: Prepared data for plotting
"""
plot_data = []
# Map plotting metric names to the underlying column name keys.
metric_mapping = {
"Precision": "precision",
"Sensitivity": "sensitivity",
"F-measure": "f_measure",
}
for complexity in complexities:
for metric in metrics:
try:
col_name = f"{metric_mapping[metric]}_mean"
# Check if we have data for this complexity
ont_filtered = ont_stats.filter(pl.col("complexity") == complexity)
illumina_filtered = illumina_stats.filter(
pl.col("complexity") == complexity
)
if ont_filtered.height == 0 or illumina_filtered.height == 0:
logger.warning(f"No data available for {complexity} {metric}")
continue
ont_value = ont_filtered.get_column(col_name)[0]
illumina_value = illumina_filtered.get_column(col_name)[0]
# Get adjusted p-value from the statistical results
stat_df = stat_results.get(complexity)
if stat_df is not None:
stat_rows = stat_df.filter(
pl.col("metric") == metric_mapping[metric]
).to_dicts()
adjusted_p_value = (
stat_rows[0]["adjusted_p_value"] if stat_rows else None
)
else:
adjusted_p_value = None
significance = (
_get_significance_level(adjusted_p_value)
if adjusted_p_value is not None
else ""
)
plot_data.extend(
[
{
"Complexity": complexity.upper(),
"Metric": metric,
"Technology": "long-read",
"Value": ont_value,
"Significance": significance,
},
{
"Complexity": complexity.upper(),
"Metric": metric,
"Technology": "short-read",
"Value": illumina_value,
"Significance": significance,
},
]
)
except Exception as e:
logger.error(
f"Error preparing data for {complexity} {metric}: {str(e)}"
)
continue
return pl.DataFrame(plot_data)
def plot_snv_performance_metrics(
plot_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
ylim: Tuple[float, float] = (0, 1),
title: str = "Performance Comparison",
metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
"""
Create a performance comparison plot for SNV detection between long-read and short-read technologies.
This function creates a side-by-side comparison of performance metrics for high and low complexity
regions. It automatically generates two subplots (high and low complexity) with bar plots showing
the specified performance metrics for both technologies.
Args:
plot_data_df (pl.DataFrame): Prepared data DataFrame from prepare_snv_performance_data
figsize (Tuple[int, int], optional): Figure size in inches. Defaults to (14, 6)
dpi (int, optional): Figure resolution. Defaults to 300
ylim (Tuple[float, float], optional): Y-axis limits. Defaults to (0, 1.05)
title (str, optional): Plot title. Defaults to "Performance Comparison"
metrics (Tuple[str, ...], optional): Metrics to plot. Defaults to ("Precision", "Sensitivity", "F-measure")
gs (Optional[gridspec.GridSpec], optional): GridSpec for subplot placement. Defaults to None
Returns:
Optional[Tuple[plt.Figure, np.ndarray]]: Figure and axes objects if gs is None, None otherwise
Raises:
Exception: If there's an error in creating the performance plots
"""
try:
if gs is None:
fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = [fig.add_subplot(gs[0, i]) for i in range(2)]
complexities = ["HC", "LC"]
for i, complexity in enumerate(complexities):
complexity_data = plot_data_df.filter(pl.col("Complexity") == complexity)
bars = sns.barplot(
data=complexity_data,
x="Metric",
y="Value",
hue="Technology",
errorbar=None,
ax=axes[i],
)
# Add value labels on top of each bar
for p in bars.patches:
value = p.get_height()
if value > 0: # Only annotate if value is not 0
axes[i].annotate(
f"{value:.3f}",
(p.get_x() + p.get_width() / 2.0, value),
ha="center",
va="bottom",
fontsize=8,
rotation=0,
)
complexity_label = "High" if complexity == "HC" else "Low"
axes[i].set_title(
f"SNV Performance in {complexity_label} Complexity Regions", pad=15
)
axes[i].set_ylim(ylim)
# Add significance annotations
for metric_idx, metric in enumerate(metrics):
metric_data = complexity_data.filter(pl.col("Metric") == metric)
if metric_data.height > 0:
significance = metric_data.get_column("Significance")[0]
if significance:
y = metric_data.get_column("Value").max() + 0.02
axes[i].text(
metric_idx,
y,
significance,
ha="center",
va="bottom",
color="black",
fontweight="bold",
)
axes[i].set_xlabel("")
axes[i].set_ylabel("Performance")
if i == 0:
axes[i].legend_.remove()
else:
if gs is None:
legend = axes[i].legend(
title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
)
legend.get_title().set_weight("bold")
else:
legend = axes[i].legend(title="Technology", loc="lower right")
legend.get_title().set_weight("bold")
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error creating performance plots: {str(e)}")
raise
performance_data = prepare_snv_performance_data(
ont_stats=snv_ont_stats,
illumina_stats=snv_illumina_stats,
stat_results=snv_rtg_statistical_results,
metrics=("Precision", "Sensitivity", "F-measure"),
complexities=("hc", "lc"),
)
snv_performance_plot = plot_snv_performance_metrics(
plot_data_df=performance_data,
)
2. Error Analysis¶
In [26]:
def _get_vcfeval_snv_error_paths(
sample_id: str, tech: str, complexity: str, base_dir: Path = base_path
) -> Tuple[Path, Path, Path]:
"""
Generate paths for VCF evaluation files.
Args:
sample_id: Sample identifier
tech: Technology type (ont/illumina)
complexity: Genomic complexity region (hc/lc)
base_dir: Base directory for project data
Returns:
Tuple of Paths for (false positives, false negatives, query) VCF files
"""
vcfeval_dir = base_dir / complexity / "aggregate" / f"{sample_id}.snv"
return (
vcfeval_dir / "fp.vcf.gz",
vcfeval_dir / "fn.vcf.gz",
vcfeval_dir / "tp.vcf.gz",
)
def _count_snv_types(vcf_file: Path) -> Dict[str, int]:
"""
Count different types of SNVs in a VCF file.
Args:
vcf_file: Path to VCF file
Returns:
Dictionary mapping SNV types to their counts
"""
snv_counts: Dict[str, int] = {}
try:
with pysam.VariantFile(str(vcf_file)) as vcf:
for record in vcf:
ref = record.ref
alt = record.alts[0]
if len(ref) == 1 and len(alt) == 1:
snv_type = f"{ref}>{alt}"
snv_counts[snv_type] = snv_counts.get(snv_type, 0) + 1
except Exception as e:
logger.error(f"Error counting SNV types in {vcf_file}: {str(e)}")
return snv_counts
def count_total_variants(vcf_file: Path) -> int:
"""
Count total number of variants in a VCF file.
Args:
vcf_file: Path to VCF file
Returns:
Total number of variants
"""
try:
with pysam.VariantFile(str(vcf_file)) as vcf:
return sum(1 for _ in vcf)
except Exception as e:
logger.error(f"Error counting variants in {vcf_file}: {str(e)}")
return 0
def calculate_snv_error_rates(
sample_ids: pl.DataFrame,
technologies: List[str] = ["ont", "illumina"],
complexities: List[str] = ["hc", "lc"],
base_dir: Path = base_path,
) -> Tuple[
Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], Dict[str, Dict[str, int]]
]:
"""
Calculate SNV error rates across samples and technologies.
Args:
sample_ids: DataFrame containing sample ID mappings
technologies: List of sequencing technologies
complexities: List of genomic complexity regions
base_dir: Base directory for project data
Returns:
Tuple containing:
- Nested dictionary of error rates
- Dictionary of sample counts per technology and complexity
"""
snv_types = [f"{ref}>{alt}" for ref in "ACGT" for alt in "ACGT" if ref != alt]
snv_error_rates = {
tech: {comp: {"FP": {}, "FN": {}} for comp in complexities}
for tech in technologies
}
sample_counts = {tech: {comp: 0 for comp in complexities} for tech in technologies}
for row in sample_ids.iter_rows(named=True):
for tech in technologies:
sample_id = row["ont_id"] if tech == "ont" else row["lp_id"]
for complexity in complexities:
fp_vcf, fn_vcf, query_vcf = _get_vcfeval_snv_error_paths(
sample_id, tech, complexity, base_dir
)
if not all(path.exists() for path in [fp_vcf, fn_vcf, query_vcf]):
logger.warning(
f"VCF files not found for {sample_id}, {tech}, {complexity}"
)
continue
sample_counts[tech][complexity] += 1
total_variants = count_total_variants(query_vcf)
if total_variants == 0:
logger.warning(
f"No variants found for {sample_id}, {tech}, {complexity}"
)
continue
for error_type, vcf_path in [("FP", fp_vcf), ("FN", fn_vcf)]:
counts = _count_snv_types(vcf_path)
for snv_type in snv_types:
if (
snv_type
not in snv_error_rates[tech][complexity][error_type]
):
snv_error_rates[tech][complexity][error_type][snv_type] = []
error_rate = counts.get(snv_type, 0) / total_variants
snv_error_rates[tech][complexity][error_type][snv_type].append(
error_rate
)
return snv_error_rates, sample_counts
def prepare_snv_error_data(
snv_error_rates: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]
) -> pl.DataFrame:
"""
Prepare SNV error rate data for visualization.
Args:
snv_error_rates: Nested dictionary containing error rates
Returns:
DataFrame containing processed error rate data
"""
tech_mapping = {"ont": "long-read", "illumina": "short-read"}
plot_data = []
for tech, tech_data in snv_error_rates.items():
display_tech = tech_mapping[tech]
for complexity, comp_data in tech_data.items():
for error_type, error_data in comp_data.items():
for snv_type, rates in error_data.items():
plot_data.append(
{
"Technology": display_tech,
"Complexity": complexity.upper(),
"Error_Type": error_type,
"SNV_Type": snv_type,
"Error_Rate": np.mean(rates) if rates else 0.0,
}
)
return pl.DataFrame(plot_data)
def perform_statistical_tests(
snv_error_rates: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
sample_counts: Dict[str, Dict[str, int]],
) -> pl.DataFrame:
"""
Perform statistical tests comparing error rates between technologies.
Args:
snv_error_rates: Nested dictionary containing error rates
sample_counts: Dictionary of sample counts
Returns:
DataFrame containing statistical test results
"""
results = []
p_values = []
for complexity in ["hc", "lc"]:
for error_type in ["FP", "FN"]:
for snv_type in snv_error_rates["ont"][complexity][error_type]:
long_read_rates = snv_error_rates["ont"][complexity][error_type][
snv_type
]
short_read_rates = snv_error_rates["illumina"][complexity][error_type][
snv_type
]
if not long_read_rates or not short_read_rates:
logger.warning(
f"No data for {complexity}, {error_type}, {snv_type}"
)
continue
n = min(len(long_read_rates), len(short_read_rates))
try:
t_stat, p_val = stats.ttest_rel(
long_read_rates[:n], short_read_rates[:n]
)
except Exception as e:
logger.error(
f"Error in t-test for {complexity}, {error_type}, {snv_type}: {str(e)}"
)
continue
results.append(
{
"Complexity": complexity.upper(),
"Error_Type": error_type,
"SNV_Type": snv_type,
"t_statistic": t_stat,
"p_value": p_val,
"n": n,
}
)
p_values.append(p_val)
if not results:
logger.warning("No statistical test results could be calculated")
return pl.DataFrame()
try:
rejected, p_corrected, _, _ = multipletests(p_values, method="fdr_bh")
except Exception as e:
logger.error(f"Error in multiple testing correction: {str(e)}")
return pl.DataFrame(results)
for i, (result, p_adj, is_rejected) in enumerate(
zip(results, p_corrected, rejected)
):
result.update(
{
"p_value_adjusted": p_adj,
"significance": _get_significance_level(p_adj),
"rejected": is_rejected,
}
)
return pl.DataFrame(results)
def plot_snv_error_rates(
plot_data: pl.DataFrame,
statistical_results: pl.DataFrame,
figsize: Tuple[int, int] = (16, 12),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
ylim: Optional[Tuple[float, float]] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
"""
Create plots showing SNV error rates across technologies and complexities.
Args:
plot_data: DataFrame containing error rate data
statistical_results: DataFrame containing statistical test results
figsize: Figure dimensions as (width, height)
dpi: Figure resolution
gs: Optional GridSpec for subplot placement
ylim: Optional Y-axis limits as (min, max)
Returns:
- (fig, axes) if gs is None
- None if using a provided GridSpec
"""
if gs is None:
fig, axes = plt.subplots(2, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = np.array(
[[fig.add_subplot(gs[i, j]) for j in range(2)] for i in range(2)]
)
try:
complexities = sorted(plot_data["Complexity"].unique().to_list())
error_types = sorted(plot_data["Error_Type"].unique().to_list())
for i, complexity in enumerate(complexities):
for j, error_type in enumerate(error_types):
subset = plot_data.filter(
(pl.col("Complexity") == complexity)
& (pl.col("Error_Type") == error_type)
).to_pandas()
if subset.empty:
logger.warning(f"No data for {complexity}, {error_type}")
continue
sns.barplot(
data=subset,
x="SNV_Type",
y="Error_Rate",
hue="Technology",
ax=axes[i, j],
)
axes[i, j].get_legend().remove()
# Set title based on complexity
title_complexity = "High" if "HC" in complexity else "Low"
axes[i, j].set_title(
f"{error_type} Rates in {title_complexity} Complexity Regions "
)
axes[i, j].set_xlabel("SNV Type")
axes[i, j].set_ylabel("Error Rate (%)")
axes[i, j].yaxis.set_major_formatter(mticker.PercentFormatter(1))
if ylim:
axes[i, j].set_ylim(ylim)
# Add significance annotations
for idx, snv_type in enumerate(subset["SNV_Type"].unique()):
filtered_results = statistical_results.filter(
(pl.col("Complexity") == complexity)
& (pl.col("Error_Type") == error_type)
& (pl.col("SNV_Type") == snv_type)
)
if not filtered_results.is_empty():
max_height = subset[subset["SNV_Type"] == snv_type][
"Error_Rate"
].max()
significance = filtered_results.select(
pl.col("significance")
).item()
axes[i, j].text(
idx,
max_height,
significance,
ha="center",
va="bottom",
fontweight="bold",
)
if (i, j) == (0, 1) and gs is None:
axes[i, j].legend(
title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
)
axes[i, j].get_legend().get_title().set_weight("bold")
except Exception as e:
logger.error(f"Error creating SNV error rate plots: {str(e)}")
raise
if gs is None:
plt.tight_layout()
return fig
return None
snv_error_rates, sample_counts = calculate_snv_error_rates(
sample_ids,
technologies=snv_config.technologies,
complexities=snv_config.complexities,
)
snv_error_plot_data = prepare_snv_error_data(snv_error_rates)
snv_error_statistical_results = perform_statistical_tests(
snv_error_rates, sample_counts
)
snv_error_rate_plot = plot_snv_error_rates(
snv_error_plot_data, snv_error_statistical_results
)
with pl.Config(tbl_rows=len(snv_error_statistical_results)):
display(snv_error_statistical_results)
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_A09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A048_09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F04.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A079_07.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F02.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A081_91.snv/fp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_H09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_A07.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_A07.snv/fn.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/lc/aggregate/A149_01.snv/tp.vcf.gz.tbi [W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_C03.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_D03.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F01.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_C05.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A160_96.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A162_09.snv/tp.vcf.gz.tbi
shape: (48, 9)
| Complexity | Error_Type | SNV_Type | t_statistic | p_value | n | p_value_adjusted | significance | rejected |
|---|---|---|---|---|---|---|---|---|
| str | str | str | f64 | f64 | i64 | f64 | str | i64 |
| "HC" | "FP" | "A>C" | 4.796801 | 0.000349 | 14 | 0.00082 | "***" | 1 |
| "HC" | "FP" | "A>G" | 5.04952 | 0.000223 | 14 | 0.000708 | "***" | 1 |
| "HC" | "FP" | "A>T" | 5.01635 | 0.000236 | 14 | 0.000708 | "***" | 1 |
| "HC" | "FP" | "C>A" | 4.667472 | 0.00044 | 14 | 0.000919 | "***" | 1 |
| "HC" | "FP" | "C>G" | 5.417316 | 0.000118 | 14 | 0.00047 | "***" | 1 |
| "HC" | "FP" | "C>T" | 4.897777 | 0.000291 | 14 | 0.000736 | "***" | 1 |
| "HC" | "FP" | "G>A" | 4.78126 | 0.000359 | 14 | 0.00082 | "***" | 1 |
| "HC" | "FP" | "G>C" | 5.169453 | 0.00018 | 14 | 0.000619 | "***" | 1 |
| "HC" | "FP" | "G>T" | 4.753991 | 0.000377 | 14 | 0.000822 | "***" | 1 |
| "HC" | "FP" | "T>A" | 4.296175 | 0.000869 | 14 | 0.001739 | "**" | 1 |
| "HC" | "FP" | "T>C" | 4.947939 | 0.000266 | 14 | 0.000711 | "***" | 1 |
| "HC" | "FP" | "T>G" | 4.94774 | 0.000267 | 14 | 0.000711 | "***" | 1 |
| "HC" | "FN" | "A>C" | 2.907751 | 0.012225 | 14 | 0.016969 | "*" | 1 |
| "HC" | "FN" | "A>G" | 2.944984 | 0.011381 | 14 | 0.016969 | "*" | 1 |
| "HC" | "FN" | "A>T" | 2.755924 | 0.016352 | 14 | 0.020125 | "*" | 1 |
| "HC" | "FN" | "C>A" | 2.927763 | 0.011764 | 14 | 0.016969 | "*" | 1 |
| "HC" | "FN" | "C>G" | 3.321026 | 0.00552 | 14 | 0.009463 | "**" | 1 |
| "HC" | "FN" | "C>T" | 3.08497 | 0.008695 | 14 | 0.014391 | "*" | 1 |
| "HC" | "FN" | "G>A" | 2.901455 | 0.012373 | 14 | 0.016969 | "*" | 1 |
| "HC" | "FN" | "G>C" | 3.324118 | 0.005487 | 14 | 0.009463 | "**" | 1 |
| "HC" | "FN" | "G>T" | 2.811492 | 0.014703 | 14 | 0.019604 | "*" | 1 |
| "HC" | "FN" | "T>A" | 2.772484 | 0.015842 | 14 | 0.020125 | "*" | 1 |
| "HC" | "FN" | "T>C" | 3.003737 | 0.010166 | 14 | 0.016265 | "*" | 1 |
| "HC" | "FN" | "T>G" | 2.966015 | 0.01093 | 14 | 0.016924 | "*" | 1 |
| "LC" | "FP" | "A>C" | 11.534503 | 3.3539e-8 | 14 | 0.000002 | "***" | 1 |
| "LC" | "FP" | "A>G" | 6.409737 | 0.000023 | 14 | 0.000111 | "***" | 1 |
| "LC" | "FP" | "A>T" | 8.259462 | 0.000002 | 14 | 0.000015 | "***" | 1 |
| "LC" | "FP" | "C>A" | 8.883831 | 6.9737e-7 | 14 | 0.00001 | "***" | 1 |
| "LC" | "FP" | "C>G" | 5.241883 | 0.000159 | 14 | 0.000587 | "***" | 1 |
| "LC" | "FP" | "C>T" | 7.908565 | 0.000003 | 14 | 0.000017 | "***" | 1 |
| "LC" | "FP" | "G>A" | 7.326255 | 0.000006 | 14 | 0.000031 | "***" | 1 |
| "LC" | "FP" | "G>C" | 8.778636 | 7.9751e-7 | 14 | 0.00001 | "***" | 1 |
| "LC" | "FP" | "G>T" | 7.906364 | 0.000003 | 14 | 0.000017 | "***" | 1 |
| "LC" | "FP" | "T>A" | 9.238933 | 4.4716e-7 | 14 | 0.00001 | "***" | 1 |
| "LC" | "FP" | "T>C" | 5.687115 | 0.000075 | 14 | 0.000325 | "***" | 1 |
| "LC" | "FP" | "T>G" | 7.502373 | 0.000004 | 14 | 0.000027 | "***" | 1 |
| "LC" | "FN" | "A>C" | 4.056768 | 0.001359 | 14 | 0.002609 | "**" | 1 |
| "LC" | "FN" | "A>G" | 1.183548 | 0.257783 | 14 | 0.268991 | "" | 0 |
| "LC" | "FN" | "A>T" | 1.158238 | 0.267603 | 14 | 0.273297 | "" | 0 |
| "LC" | "FN" | "C>A" | 2.761842 | 0.016168 | 14 | 0.020125 | "*" | 1 |
| "LC" | "FN" | "C>G" | 0.065988 | 0.948392 | 14 | 0.948392 | "" | 0 |
| "LC" | "FN" | "C>T" | 1.749879 | 0.103689 | 14 | 0.115746 | "" | 0 |
| "LC" | "FN" | "G>A" | 1.343779 | 0.202005 | 14 | 0.215472 | "" | 0 |
| "LC" | "FN" | "G>C" | 2.364925 | 0.034259 | 14 | 0.041111 | "*" | 1 |
| "LC" | "FN" | "G>T" | 3.574485 | 0.003394 | 14 | 0.006265 | "**" | 1 |
| "LC" | "FN" | "T>A" | 1.826792 | 0.090773 | 14 | 0.103741 | "" | 0 |
| "LC" | "FN" | "T>C" | 1.669465 | 0.118915 | 14 | 0.129725 | "" | 0 |
| "LC" | "FN" | "T>G" | 1.857803 | 0.085987 | 14 | 0.100668 | "" | 0 |
3. Combined Plots¶
In [27]:
def create_combined_snv_metrics_plot(
ont_stats: pl.DataFrame,
illumina_stats: pl.DataFrame,
stat_results: Dict[str, pl.DataFrame],
error_plot_data: pl.DataFrame,
error_statistical_results: pl.DataFrame,
figsize: Tuple[int, int] = (12, 10),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing SNV performance metrics and error rates.
Args:
ont_stats: ONT statistics DataFrame
illumina_stats: Illumina statistics DataFrame
stat_results: Dictionary containing statistical test results
error_plot_data: DataFrame containing error rate data
error_statistical_results: DataFrame containing error rate statistical results
figsize: Figure dimensions as (width, height)
dpi: Figure resolution
Returns:
Combined figure object containing all plots
Raises:
Exception: If there's an error creating the combined plot
"""
try:
# Create figure with GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(3, 2, height_ratios=[1, 1, 1])
# Prepare data for performance metrics plot
performance_data = prepare_snv_performance_data(
ont_stats=ont_stats,
illumina_stats=illumina_stats,
stat_results=stat_results,
metrics=("Precision", "Sensitivity", "F-measure"),
complexities=("hc", "lc"),
)
# Plot A & B: SNV Performance Metrics
plot_snv_performance_metrics(
plot_data_df=performance_data,
gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0, :]),
)
# Plot C, D, E, F: Error Rates
plot_snv_error_rates(
plot_data=error_plot_data,
statistical_results=error_statistical_results,
gs=gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1:, :]),
)
# Add panel labels
for i, ax in enumerate(fig.axes):
label = chr(ord("A") + i)
ax.text(
-0.12,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined SNV metrics plot: {str(e)}")
raise
combined_snv_metrics_plot = create_combined_snv_metrics_plot(
ont_stats=snv_ont_stats,
illumina_stats=snv_illumina_stats,
stat_results=snv_rtg_statistical_results,
error_plot_data=snv_error_plot_data,
error_statistical_results=snv_error_statistical_results,
)
In [28]:
@dataclass
class IndelRTGMetrics:
"""Data class for indel RTG vcfeval metrics."""
true_pos_baseline: int
true_pos_call: int
false_pos: int
false_neg: int
precision: float
sensitivity: float
f_measure: float
@dataclass
class IndelAnalysisConfig:
"""Configuration for indel analysis."""
technologies: tuple[str, ...] = ("ont",)
complexities: tuple[str, ...] = ("hc", "lc")
metrics_to_test: tuple[str, ...] = ("precision", "sensitivity", "f_measure")
base_dir: Path = Path("/scratch/prj/ppn_als_longread/ont-benchmark")
subdir_mapping: dict[str, str] = field(
default_factory=lambda: {"hc": "aggregate", "lc": "1bp"}
)
def _get_vcfeval_indel_paths(
sample_id: str, complexity: str, config: IndelAnalysisConfig
) -> tuple[Path, Path, Path]:
"""
Generate paths for indel VCF evaluation files.
Args:
sample_id: Sample identifier
complexity: Genomic complexity region
config: Analysis configuration
Returns:
Tuple of Paths for (false positives, false negatives, true positives) VCF files
"""
subdir = config.subdir_mapping.get(complexity, "")
vcfeval_dir = (
config.base_dir
/ "output"
/ "indel"
/ "rtg_vcfeval"
/ complexity
/ subdir
/ f"{sample_id}.indel"
)
return (
vcfeval_dir / "fp.vcf.gz",
vcfeval_dir / "fn.vcf.gz",
vcfeval_dir / "tp.vcf.gz",
)
def collect_indel_rtg_metrics(
sample_ids: pl.DataFrame, config: IndelAnalysisConfig
) -> Dict[str, pl.DataFrame]:
"""
Collect indel metrics from RTG vcfeval summary files.
Args:
sample_ids: DataFrame containing sample IDs
config: Analysis configuration
Returns:
Dictionary containing DataFrames with metrics for each technology
"""
metrics_data = {tech: [] for tech in config.technologies}
for row in sample_ids.iter_rows(named=True):
for tech in config.technologies:
for complexity in config.complexities:
sample_id = row["ont_id"]
fp_vcf, fn_vcf, tp_vcf = _get_vcfeval_indel_paths(
sample_id, complexity, config
)
if not all(path.exists() for path in [fp_vcf, fn_vcf, tp_vcf]):
logger.warning(f"VCF files not found for {sample_id}, {complexity}")
continue
if summary := _read_rtg_summary(fp_vcf.parent / "summary.txt"):
metrics_entry = {
"sample_id": sample_id,
"complexity": complexity,
**summary,
}
metrics_data[tech].append(metrics_entry)
else:
logger.warning(
f"Skipping empty summary for {sample_id}, {tech}, {complexity}"
)
return {tech: pl.DataFrame(data) for tech, data in metrics_data.items() if data}
def display_indel_statistics(
rtg_metrics_dfs: Dict[str, pl.DataFrame],
config: IndelAnalysisConfig,
) -> Dict[str, Dict[str, pl.DataFrame]]:
"""
Process and compile statistics for indel analysis.
Args:
rtg_metrics_dfs: Dictionary of DataFrames containing metrics
config: Analysis configuration
Returns:
Dictionary containing ONT statistics for each complexity
"""
stats_data = {}
for complexity in config.complexities:
ont_stats = _calculate_rtg_statistics(
rtg_metrics_dfs["ont"].filter(pl.col("complexity") == complexity)
)
stats_data[complexity] = {"ont": ont_stats}
return stats_data
indel_config = IndelAnalysisConfig()
indel_rtg_metrics_dfs = collect_indel_rtg_metrics(sample_ids, indel_config)
indel_statistics = display_indel_statistics(indel_rtg_metrics_dfs, indel_config)
# Create combined DataFrame containing data for all complexities
indel_ont_stats = pl.concat(
[
indel_statistics[complexity]["ont"].with_columns(
pl.lit(complexity).alias("complexity")
)
for complexity in indel_config.complexities
]
)
print("ONT Indel Statistics:")
display(indel_ont_stats)
ONT Indel Statistics:
shape: (2, 16)
| complexity | precision_mean | precision_std | precision_median | precision_min | precision_max | sensitivity_mean | sensitivity_std | sensitivity_median | sensitivity_min | sensitivity_max | f_measure_mean | f_measure_std | f_measure_median | f_measure_min | f_measure_max |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| str | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "hc" | 0.786486 | 0.066836 | 0.8039 | 0.6739 | 0.8684 | 0.905393 | 0.035058 | 0.91965 | 0.8308 | 0.9402 | 0.841286 | 0.053117 | 0.85785 | 0.7476 | 0.9026 |
| "lc" | 0.394586 | 0.09414 | 0.3839 | 0.2879 | 0.5593 | 0.461543 | 0.03951 | 0.46995 | 0.3936 | 0.5116 | 0.423086 | 0.06925 | 0.4235 | 0.3338 | 0.5344 |
In [29]:
def prepare_indel_performance_data(
ont_stats: pl.DataFrame,
metrics: Tuple[str, ...],
complexities: Tuple[str, ...],
) -> pl.DataFrame:
"""
Prepare data for indel performance visualization.
Args:
ont_stats: ONT statistics DataFrame
metrics: Metrics to plot
complexities: Complexity levels to plot
Returns:
pl.DataFrame: Prepared data for plotting
"""
plot_data = []
# Map plotting metric names to the underlying column name keys
metric_mapping = {
"Precision": "precision",
"Sensitivity": "sensitivity",
"F-measure": "f_measure",
}
for complexity in complexities:
for metric in metrics:
try:
col_name = f"{metric_mapping[metric]}_mean"
# Check if we have data for this complexity
ont_filtered = ont_stats.filter(pl.col("complexity") == complexity)
if ont_filtered.height == 0:
logger.warning(f"No data available for {complexity} {metric}")
continue
ont_value = ont_filtered.get_column(col_name)[0]
plot_data.append(
{
"Complexity": complexity.upper(),
"Metric": metric,
"Technology": "long-read",
"Value": ont_value,
"Significance": "", # No significance for indels as we only compare ONT
}
)
except Exception as e:
logger.error(
f"Error preparing data for {complexity} {metric}: {str(e)}"
)
continue
return pl.DataFrame(plot_data)
def plot_indel_performance_metrics(
plot_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 6),
dpi: int = 300,
ylim: Tuple[float, float] = (0, 1),
title: str = "Indel Performance",
metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create a performance plot for indel detection.
Args:
plot_data_df: Prepared data DataFrame from prepare_indel_performance_data
figsize: Figure size in inches
dpi: Figure resolution
ylim: Y-axis limits
title: Plot title
metrics: Metrics to plot
gs: GridSpec for subplot placement
Returns:
Optional[plt.Figure]: Figure object if gs is None, None otherwise
"""
try:
if gs is None:
fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = [fig.add_subplot(gs[0, i]) for i in range(2)]
complexities = ["HC", "LC"]
for i, complexity in enumerate(complexities):
complexity_data = plot_data_df.filter(pl.col("Complexity") == complexity)
bars = sns.barplot(
data=complexity_data,
x="Metric",
y="Value",
errorbar=None,
ax=axes[i],
)
# Add value labels on top of each bar
for p in bars.patches:
value = p.get_height()
if value > 0:
axes[i].annotate(
f"{value:.3f}",
(p.get_x() + p.get_width() / 2.0, value),
ha="center",
va="bottom",
fontsize=8,
rotation=0,
)
complexity_label = "High" if complexity == "HC" else "Low"
axes[i].set_title(f"{complexity_label} Complexity", pad=15)
axes[i].set_ylim(ylim)
axes[i].set_xlabel("")
axes[i].set_ylabel("Performance")
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error creating performance plots: {str(e)}")
raise
indel_performance_data = prepare_indel_performance_data(
ont_stats=indel_ont_stats,
metrics=("Precision", "Sensitivity", "F-measure"),
complexities=("hc", "lc"),
)
indel_performance_plot = plot_indel_performance_metrics(
plot_data_df=indel_performance_data,
)
2. Error Analysis¶
In [30]:
@dataclass
class IndelSizeMetrics:
"""Data class for indel size-based metrics."""
size: int
complexity: str
true_pos_baseline: int
true_pos_call: int
false_pos: int
false_neg: int
precision: float
sensitivity: float
f_measure: float
def get_vcfeval_indel_size_paths(
sample_id: str, size: int, complexity: str, base_dir: Path
) -> Tuple[Path, Path, Path, Path, Path]:
"""
Generate paths for size-specific indel VCF evaluation files.
Args:
sample_id: Sample identifier
size: Size of indels in base pairs
complexity: Genomic complexity region (hc/lc)
base_dir: Base directory for project data
Returns:
Tuple containing paths for summary.txt, tp-baseline.vcf.gz, tp.vcf.gz,
fp.vcf.gz, and fn.vcf.gz
"""
base_path = (
base_dir
/ "output"
/ "indel"
/ "rtg_vcfeval"
/ complexity
/ f"{size}bp"
/ f"{sample_id}.indel"
)
return (
base_path / "summary.txt",
base_path / "tp-baseline.vcf.gz",
base_path / "tp.vcf.gz",
base_path / "fp.vcf.gz",
base_path / "fn.vcf.gz",
)
def collect_indel_size_metrics(
sample_ids: pl.DataFrame,
complexities: List[str],
base_dir: Path,
) -> Dict[int, Dict[str, List[IndelSizeMetrics]]]:
"""
Collect indel metrics stratified by size from RTG vcfeval output.
Args:
sample_ids: DataFrame containing sample ID mappings
complexities: List of genomic complexity regions
base_dir: Base directory containing VCF files
Returns:
Nested dictionary of metrics organized by size and complexity
"""
metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]] = defaultdict(
lambda: {comp: [] for comp in complexities}
)
for row in sample_ids.iter_rows(named=True):
sample_id = row["ont_id"]
for complexity in complexities:
indel_dir = base_dir / "output" / "indel" / "rtg_vcfeval" / complexity
if not indel_dir.exists():
continue
for size_dir in indel_dir.glob("*bp"):
try:
size = int(size_dir.name.replace("bp", ""))
except ValueError:
continue
summary_path = size_dir / f"{sample_id}.indel" / "summary.txt"
if not summary_path.is_file():
logger.warning(
f"Summary file not found for {sample_id}, size {size}, {complexity}"
)
continue
try:
summary = _read_rtg_summary(summary_path)
if summary:
metrics[size][complexity].append(
IndelSizeMetrics(
size=size, complexity=complexity, **summary
)
)
except Exception as e:
logger.error(
f"Error processing summary for {sample_id}, size {size}, "
f"complexity {complexity}: {str(e)}"
)
return metrics
def process_indel_size_metrics(
indel_metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]]
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""
Process indel metrics and calculate size-based statistics.
Args:
indel_metrics: Nested dictionary of metrics by size and complexity
Returns:
Tuple containing raw metrics DataFrame and aggregated statistics DataFrame
"""
all_data = []
for size, size_data in indel_metrics.items():
for complexity, metrics_list in size_data.items():
for metric in metrics_list:
all_data.append(
{
"size": metric.size,
"complexity": metric.complexity.upper(),
"precision": metric.precision,
"sensitivity": metric.sensitivity,
"f_measure": metric.f_measure,
"true_pos_baseline": metric.true_pos_baseline,
"true_pos_call": metric.true_pos_call,
"false_pos": metric.false_pos,
"false_neg": metric.false_neg,
}
)
metrics_df = pl.DataFrame(all_data)
stats_df = metrics_df.group_by(["size", "complexity"]).agg(
[
pl.col("precision").mean().alias("precision_mean"),
pl.col("precision").std().alias("precision_std"),
pl.col("sensitivity").mean().alias("sensitivity_mean"),
pl.col("sensitivity").std().alias("sensitivity_std"),
pl.col("f_measure").mean().alias("f_measure_mean"),
pl.col("f_measure").std().alias("f_measure_std"),
]
)
return metrics_df, stats_df
def prepare_indel_size_performance_data(
metrics_df: pl.DataFrame,
metrics: List[str] = ["precision", "sensitivity", "f_measure"],
) -> pl.DataFrame:
"""
Prepare indel performance data by size for visualization.
Args:
metrics_df: DataFrame containing raw metrics for each sample
metrics: List of performance metrics to analyze
Returns:
DataFrame containing processed performance data
"""
plot_data = []
for row in metrics_df.iter_rows(named=True):
for metric in metrics:
plot_data.append(
{
"size": row["size"],
"complexity": row["complexity"],
"metric": metric.capitalize(),
"value": row[metric],
}
)
return pl.DataFrame(plot_data)
def stretched_exponential(x, a, b, c, β):
"""
Stretched exponential function: a*exp(-(x/b)^β) + c
Parameters:
a: amplitude
b: characteristic time/length scale
c: vertical offset
β: stretching exponent (controls decay rate variation)
"""
return a * np.exp(-((x / b) ** β)) + c
def fit_and_get_ci(x, y, func=stretched_exponential, p0=None):
"""Fit curve and calculate 95% confidence intervals using median values"""
x = np.array(x)
y = np.array(y)
sort_idx = np.argsort(x)
x = x[sort_idx]
y = y[sort_idx]
p0 = p0 if p0 is not None else [1.0, 10.0, 0.0, 1.0]
try:
popt, pcov = curve_fit(
func,
x,
y,
p0=p0,
maxfev=5000,
bounds=([0, 0, -np.inf, 0], [np.inf, np.inf, np.inf, 10]),
)
perr = np.sqrt(np.diag(pcov))
x_smooth = np.linspace(1, 50, 100)
y_fit = func(x_smooth, *popt)
y_err = np.zeros(len(x_smooth))
for i in range(len(x_smooth)):
jac = np.zeros(4)
dx = 1e-6
for j in range(4):
params = list(popt)
params[j] += dx
jac[j] = (func(x_smooth[i], *params) - y_fit[i]) / dx
y_err[i] = np.sqrt(np.sum((jac * perr) ** 2))
return x_smooth, y_fit, y_err, popt, perr
except Exception as e:
logger.error(f"Error in curve fitting: {str(e)}")
x_smooth = np.linspace(1, 50, 100)
y_fit = np.zeros_like(x_smooth)
y_err = np.zeros_like(x_smooth)
popt = [0, 0, 0, 0]
perr = [0, 0, 0, 0]
return x_smooth, y_fit, y_err, popt, perr
def analyze_indel_size_performance(
metrics_df: pl.DataFrame,
indel_metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]],
metrics: List[str] = ["precision", "sensitivity", "f_measure"],
) -> Tuple[pl.DataFrame, Dict]:
"""
Analyze indel performance by size categories and fit regression curves.
"""
# Create size categories
metrics_df = metrics_df.with_columns(
pl.when(pl.col("size") <= 5)
.then(pl.lit("1-5bp"))
.when(pl.col("size") <= 10)
.then(pl.lit("6-10bp"))
.when(pl.col("size") <= 20)
.then(pl.lit("11-20bp"))
.when(pl.col("size") <= 50)
.then(pl.lit("21-50bp"))
.alias("size_category")
)
# Count variants by size category and complexity
size_counts = {complexity: defaultdict(int) for complexity in ["hc", "lc"]}
for size, size_data in indel_metrics.items():
size_int = int(size)
for complexity, metrics_list in size_data.items():
for metric in metrics_list:
if size_int <= 5:
size_cat = "1-5bp"
elif size_int <= 10:
size_cat = "6-10bp"
elif size_int <= 20:
size_cat = "11-20bp"
elif size_int <= 50:
size_cat = "21-50bp"
else:
continue
size_counts[complexity.lower()][size_cat] += (
metric.true_pos_baseline + metric.false_pos + metric.false_neg
)
# Create summary statistics
summary_stats = []
for complexity in ["HC", "LC"]:
for size_cat in ["1-5bp", "6-10bp", "11-20bp", "21-50bp"]:
group = metrics_df.filter(
(pl.col("complexity") == complexity)
& (pl.col("size_category") == size_cat)
)
complexity_key = complexity.lower()
total_count = sum(size_counts[complexity_key].values())
category_count = size_counts[complexity_key].get(size_cat, 0)
stats_dict = {
"complexity": complexity,
"size_category": size_cat,
"count": category_count,
"proportion": (
(category_count / total_count) * 100 if total_count > 0 else 0
),
}
for metric in metrics:
metric_values = group.get_column(metric)
if len(metric_values) > 0:
metric_median = np.median(metric_values)
q1 = np.percentile(metric_values, 25)
q3 = np.percentile(metric_values, 75)
iqr = q3 - q1
else:
metric_median = float("nan")
iqr = float("nan")
stats_dict[f"{metric}_median"] = metric_median
stats_dict[f"{metric}_iqr"] = iqr
summary_stats.append(stats_dict)
summary_df = pl.DataFrame(summary_stats)
# Round values
summary_df = summary_df.with_columns(
[
pl.col("proportion").round(1).alias("proportion"),
*[
pl.col(f"{metric}_median").round(3).alias(f"{metric}_median")
for metric in metrics
],
*[
pl.col(f"{metric}_iqr").round(3).alias(f"{metric}_iqr")
for metric in metrics
],
]
)
# Fit regression curves
regression_results = {}
for metric in metrics:
metric_cap = metric.capitalize()
regression_results[metric_cap] = {}
for complexity in ["HC", "LC"]:
data = metrics_df.filter(pl.col("complexity") == complexity)
# Calculate median by size
median_by_size = data.group_by("size").agg(
pl.col(metric).median().alias("median_value")
)
x = median_by_size.get_column("size").to_numpy()
y = median_by_size.get_column("median_value").to_numpy()
x_smooth, y_fit, y_err, popt, perr = fit_and_get_ci(
x, y, stretched_exponential, p0=[1.0, 10.0, 0.0, 1.0]
)
regression_results[metric_cap][complexity] = {
"x_smooth": x_smooth,
"y_fit": y_fit,
"y_err": y_err,
"popt": popt,
"perr": perr,
}
# Calculate R-squared and p-value
if len(x) > 4: # Need more data points than parameters
residuals = y - stretched_exponential(x, *popt)
ss_res = np.sum(residuals**2)
ss_tot = np.sum((y - np.mean(y)) ** 2)
r_squared = 1 - (ss_res / ss_tot)
n = len(x)
p = 4 # Number of parameters
f_stat = (r_squared / p) / ((1 - r_squared) / (n - p - 1))
p_value = 1 - stats.f.cdf(f_stat, p, n - p - 1)
logger.info(f"{metric_cap} - {complexity} Region:")
logger.info(f"R-squared: {r_squared:.3f} (p-value: {p_value:.3e})")
return summary_df, regression_results
def plot_indel_size_performance(
plot_data: pl.DataFrame,
regression_results: Dict,
figsize: Tuple[int, int] = (15, 10),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
"""
Create performance plots for indel size analysis.
Args:
plot_data (pl.DataFrame): DataFrame containing prepared performance data.
regression_results (Dict): Dictionary with regression fit data for each metric.
figsize (Tuple[int, int], optional): Figure size in inches. Defaults to (15, 10).
dpi (int, optional): Figure resolution. Defaults to 300.
gs (Optional[gridspec.GridSpec], optional): GridSpec for subplot placement.
If None, creates a standalone figure. Defaults to None.
Returns:
Optional[Tuple[plt.Figure, np.ndarray]]: If `gs` is None, returns the figure and axes array.
If `gs` is provided, returns None (the plots are added as subfigures).
"""
metrics = ["Precision", "Sensitivity", "F_measure"]
if gs is None:
fig, axes = plt.subplots(3, 1, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = np.array([fig.add_subplot(gs[i, 0]) for i in range(3)])
colors = sns.color_palette()
for i, metric in enumerate(metrics):
filtered = plot_data.filter(pl.col("metric") == metric).with_columns(
pl.col("size").cast(pl.Int64)
)
sizes = filtered["size"].to_list()
values = filtered["value"].to_list()
complexities = filtered["complexity"].to_list()
sns.boxplot(x=sizes, y=values, hue=complexities, ax=axes[i], width=0.8)
complexity_labels = {"HC": "High Complexity", "LC": "Low Complexity"}
for j, complexity in enumerate(["HC", "LC"]):
if (
metric in regression_results
and complexity in regression_results[metric]
):
results = regression_results[metric][complexity]
axes[i].plot(
results["x_smooth"],
results["y_fit"],
"-",
color=colors[j],
label=f"{complexity_labels[complexity]} line of best fit",
alpha=0.6,
linewidth=2,
)
axes[i].set_xticks(np.arange(-0.5, 50.5, 1), minor=True)
axes[i].grid(axis="x", linestyle="-", alpha=0.7, which="minor")
axes[i].set_axisbelow(True)
axes[i].set_title("F-measure" if metric == "F_measure" else metric)
axes[i].set_xlabel("Indel Size (bp)")
axes[i].set_ylabel("Performance")
axes[i].set_xlim(-0.5, 49.5)
axes[i].set_ylim(0, 1)
if i == 0:
handles, labels = axes[i].get_legend_handles_labels()
labels = [
"High Complexity" if l == "HC" else "Low Complexity" if l == "LC" else l
for l in labels
]
if gs is None:
legend = axes[i].legend(
handles,
labels,
title="Region",
bbox_to_anchor=(1, 1),
loc="upper left",
)
legend.get_title().set_weight("bold")
else:
legend = axes[i].legend(
handles,
labels,
title="Region",
loc="lower left",
)
legend.get_title().set_weight("bold")
else:
legend = axes[i].get_legend()
if legend is not None:
legend.remove()
if gs is None:
plt.tight_layout()
return fig
return None
return gs
indel_size_metrics = collect_indel_size_metrics(
sample_ids, list(indel_config.complexities), indel_config.base_dir
)
indel_raw_metrics_df, indel_aggregated_stats_df = process_indel_size_metrics(
indel_size_metrics
)
indel_size_plot_data = prepare_indel_size_performance_data(indel_raw_metrics_df)
indel_size_perf_summary_df, indel_size_perf_regression_results = (
analyze_indel_size_performance(indel_raw_metrics_df, indel_size_metrics)
)
indel_size_perf_plot = plot_indel_size_performance(
indel_size_plot_data, indel_size_perf_regression_results
)
with pl.Config(tbl_rows=len(indel_size_perf_summary_df)):
display(indel_size_perf_summary_df)
__main__ - INFO - Precision - HC Region:
__main__ - INFO - R-squared: 0.849 (p-value: 1.110e-16)
__main__ - INFO - Precision - LC Region:
__main__ - INFO - R-squared: 0.690 (p-value: 5.951e-11)
__main__ - INFO - Sensitivity - HC Region:
__main__ - INFO - R-squared: 0.670 (p-value: 2.362e-10)
__main__ - INFO - Sensitivity - LC Region:
__main__ - INFO - R-squared: -0.000 (p-value: 1.000e+00)
__main__ - INFO - F_measure - HC Region:
__main__ - INFO - R-squared: 0.913 (p-value: 1.110e-16)
__main__ - INFO - F_measure - LC Region:
__main__ - INFO - R-squared: 0.575 (p-value: 6.025e-08)
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
shape: (8, 10)
| complexity | size_category | count | proportion | precision_median | precision_iqr | sensitivity_median | sensitivity_iqr | f_measure_median | f_measure_iqr |
|---|---|---|---|---|---|---|---|---|---|
| str | str | i64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "HC" | "1-5bp" | 3314188 | 89.5 | 0.826 | 0.071 | 0.918 | 0.04 | 0.869 | 0.066 |
| "HC" | "6-10bp" | 215329 | 5.8 | 0.772 | 0.03 | 0.9 | 0.067 | 0.834 | 0.039 |
| "HC" | "11-20bp" | 116502 | 3.1 | 0.798 | 0.037 | 0.932 | 0.063 | 0.862 | 0.044 |
| "HC" | "21-50bp" | 56536 | 1.5 | 0.624 | 0.184 | 0.885 | 0.097 | 0.724 | 0.156 |
| "LC" | "1-5bp" | 12774954 | 83.5 | 0.445 | 0.116 | 0.483 | 0.109 | 0.454 | 0.124 |
| "LC" | "6-10bp" | 1374227 | 9.0 | 0.441 | 0.173 | 0.59 | 0.173 | 0.501 | 0.181 |
| "LC" | "11-20bp" | 751250 | 4.9 | 0.407 | 0.15 | 0.577 | 0.156 | 0.478 | 0.16 |
| "LC" | "21-50bp" | 392160 | 2.6 | 0.25 | 0.093 | 0.571 | 0.085 | 0.35 | 0.093 |
3. Size Distribution Analysis¶
In [31]:
def _classify_indel(ref: str, alt: str) -> Tuple[str, int]:
"""
Classify an indel as insertion or deletion and determine its length.
Args:
ref: Reference allele
alt: Alternate allele
Returns:
Tuple containing indel category (insertion/deletion) and length
"""
indel_type = "insertion" if len(alt) > len(ref) else "deletion"
indel_length = abs(len(alt) - len(ref))
return indel_type, indel_length
def analyze_indel_size_distribution(
sample_ids: pl.DataFrame,
complexities: List[str],
) -> Dict[str, Dict[str, Dict[str, Dict[str, int]]]]:
"""
Analyze indel size distribution using original VCF files.
Args:
sample_ids: DataFrame containing sample ID mappings.
complexities: List of genomic complexity regions.
Returns:
Nested dictionary containing size distributions organized by:
- First level: complexity region (hc/lc).
- Second level: technology (ont/illumina).
- Third level: indel type (insertion/deletion).
- Fourth level: dictionary mapping sizes to counts.
"""
distributions: Dict[str, Dict[str, Dict[str, DefaultDict[str, int]]]] = {
comp: {
tech: {
indel_type: defaultdict(int) for indel_type in ["insertion", "deletion"]
}
for tech in ["ont", "illumina"]
}
for comp in complexities
}
for row in sample_ids.iter_rows(named=True):
sample_id = row["ont_id"]
for complexity in complexities:
base_dir = (
indel_config.base_dir / "output" / "indel" / "rtg_vcfeval" / complexity
)
# Extract indel sizes dynamically from available directories
for indel_size_dir in base_dir.iterdir():
if not indel_size_dir.is_dir() or not indel_size_dir.name.endswith(
"bp"
):
continue # Skip non-indel size directories
vcf_dir = indel_size_dir / f"{sample_id}.indel"
query_vcf = vcf_dir / "query.vcf.gz"
truth_vcf = vcf_dir / "truth.vcf.gz"
for vcf_path, tech in [(query_vcf, "ont"), (truth_vcf, "illumina")]:
if not vcf_path.exists():
logger.warning(f"VCF file not found: {vcf_path}")
continue
try:
with pysam.VariantFile(str(vcf_path)) as vcf:
for record in vcf:
if len(record.ref) == 1 and len(record.alts[0]) == 1:
continue
indel_type, length = _classify_indel(
record.ref, record.alts[0]
)
distributions[complexity][tech][indel_type][length] += 1
except Exception as e:
logger.error(f"Error processing {vcf_path}: {str(e)}")
return distributions
def prepare_indel_size_distribution_data(
distributions: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]
) -> pl.DataFrame:
"""
Prepare indel size distribution data for visualization.
Args:
distributions: Nested dictionary containing size distributions organized by
complexity, technology, and indel type
Returns:
DataFrame containing processed distribution data with columns for Complexity,
Technology, Indel Type, Size, and Percentage
"""
plot_data = []
tech_display = {"ont": "ONT", "illumina": "Illumina"}
complexity_display = {"hc": "High Complexity", "lc": "Low Complexity"}
for complexity, tech_data in distributions.items():
for tech, type_data in tech_data.items():
for indel_type, size_data in type_data.items():
total = sum(size_data.values())
if total == 0:
continue
max_size = max(size_data.keys()) if size_data else 0
for size in range(1, max_size + 1):
count = size_data.get(size, 0)
plot_data.append(
{
"Complexity": complexity_display[complexity],
"Technology": tech_display[tech],
"Indel Type": indel_type,
"Size": size,
"Percentage": (count / total) * 100 if total > 0 else 0,
}
)
return pl.DataFrame(plot_data)
def plot_indel_size_distributions(
df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create plots showing indel size distributions across technologies and complexities.
Args:
df: DataFrame containing distribution data
figsize: Figure size in inches
dpi: Figure resolution
gs: GridSpec for subplot placement. If None, creates a standalone figure.
Returns:
Optional[plt.Figure]: If gs is None, returns the figure.
If gs is provided, returns None (plots are added as subfigures).
"""
try:
if gs is None:
fig, axes = plt.subplots(2, 2, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = np.array(
[[plt.subplot(gs[i, j]) for j in range(2)] for i in range(2)]
)
indel_types = ["insertion", "deletion"]
complexities = ["High Complexity", "Low Complexity"]
colors = sns.color_palette("colorblind")
for i, indel_type in enumerate(indel_types):
for j, complexity in enumerate(complexities):
subset = df.filter(
(pl.col("Indel Type") == indel_type)
& (pl.col("Complexity") == complexity)
)
for k, tech in enumerate(["ONT", "Illumina"]):
tech_data = subset.filter(pl.col("Technology") == tech)
if tech_data.height > 0:
tech_label = "Long-read" if tech == "ONT" else "Short-read"
sns.lineplot(
data=tech_data,
x="Size",
y="Percentage",
label=tech_label,
ax=axes[i, j],
marker="o",
markersize=4,
color=colors[k],
)
axes[i, j].set_title(f"{complexity} - {indel_type.capitalize()}s")
axes[i, j].set_xlabel("Indel Size (bp)")
axes[i, j].set_ylabel("Proportion of Indels (%)")
# Check if we're in the top-right plot (0, 1) and handle legend
if (i, j) == (0, 1):
if gs is None:
axes[i, j].legend(
title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
)
axes[i, j].get_legend().get_title().set_weight("bold")
else:
axes[i, j].legend(title="Technology", loc="upper right")
axes[i, j].get_legend().get_title().set_weight("bold")
else:
# Remove legend for all other subplots
legend = axes[i, j].get_legend()
if legend is not None:
legend.remove()
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error creating indel size distribution plots: {str(e)}")
return None
def compare_size_distributions(
distributions: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]
) -> pl.DataFrame:
"""
Perform statistical comparison of indel size distributions between ONT and Illumina
using the Kolmogorov-Smirnov test.
Args:
distributions: Nested dictionary containing indel size distributions organized by
complexity, technology, and indel type
Returns:
DataFrame containing statistical test results with columns:
- Complexity: Genomic complexity region (High/Low Complexity)
- Indel Type: Type of indel (Insertion/Deletion)
- KS Statistic: Kolmogorov-Smirnov test statistic
- p-value: Raw p-value from KS test
- Adjusted p-value: FDR-corrected p-value
"""
results = []
try:
for complexity in distributions:
for indel_type in ["insertion", "deletion"]:
# Expand distribution to actual size arrays
ont_sizes = []
illumina_sizes = []
for size, count in distributions[complexity]["ont"][indel_type].items():
ont_sizes.extend([size] * count)
for size, count in distributions[complexity]["illumina"][
indel_type
].items():
illumina_sizes.extend([size] * count)
if not ont_sizes or not illumina_sizes:
logger.warning(
f"Skipping KS test for {complexity}/{indel_type} due to empty data"
)
continue
ks_stat, p_val = stats.ks_2samp(ont_sizes, illumina_sizes)
results.append(
{
"Complexity": (
"High Complexity"
if complexity == "hc"
else "Low Complexity"
),
"Indel Type": indel_type.capitalize(),
"KS Statistic": ks_stat,
"p-value": p_val,
}
)
if not results:
logger.warning("No data available for statistical comparison")
return pl.DataFrame(
schema={
"Complexity": pl.Utf8,
"Indel Type": pl.Utf8,
"KS Statistic": pl.Float64,
"p-value": pl.Float64,
"Adjusted p-value": pl.Float64,
}
)
results_df = pl.DataFrame(results)
# Extract p-values and apply FDR correction
p_values = results_df.get_column("p-value").to_numpy()
_, adjusted_p, _, _ = multipletests(p_values, method="fdr_bh")
# Add adjusted p-values to the DataFrame
results_df = results_df.with_columns(pl.Series("Adjusted p-value", adjusted_p))
return results_df
except Exception as e:
logger.error(f"Error performing statistical comparison: {str(e)}")
return pl.DataFrame(
schema={
"Complexity": pl.Utf8,
"Indel Type": pl.Utf8,
"KS Statistic": pl.Float64,
"p-value": pl.Float64,
"Adjusted p-value": pl.Float64,
}
)
indel_distributions = analyze_indel_size_distribution(
sample_ids, list(indel_config.complexities)
)
distribution_data = prepare_indel_size_distribution_data(indel_distributions)
indel_size_dist_plot = plot_indel_size_distributions(distribution_data)
statistical_results = compare_size_distributions(indel_distributions)
print("\nStatistical Test Results (Kolmogorov-Smirnov test with FDR correction):")
with pl.Config(tbl_rows=len(statistical_results)):
display(statistical_results)
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/46bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/30bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/7bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/42bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/5bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/20bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/47bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/33bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/22bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/19bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/13bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/18bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/5bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/19bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/34bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/44bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/32bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/42bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/11bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/8bp/A081_91.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/4bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A081_91.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/24bp/A085_00.indel/query.vcf.gz.tbi [W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/24bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/25bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/37bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/17bp/A097_92.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/37bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/33bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/18bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/36bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/35bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/29bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A097_92.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/43bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/39bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/2bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/26bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/38bp/A153_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A153_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/9bp/A153_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A153_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/17bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/32bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/34bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/6bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/9bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/18bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/22bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/16bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/23bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/10bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/10bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/25bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/14bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/43bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/32bp/A154_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/34bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/49bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/30bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/23bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/34bp/A157_02.indel/truth.vcf.gz.tbi [W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/11bp/A157_02.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/19bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/3bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/3bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/33bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/45bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/12bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/47bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/35bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/41bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/27bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/6bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/23bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A162_09.indel/truth.vcf.gz.tbi
Statistical Test Results (Kolmogorov-Smirnov test with FDR correction):
shape: (4, 5)
| Complexity | Indel Type | KS Statistic | p-value | Adjusted p-value |
|---|---|---|---|---|
| str | str | f64 | f64 | f64 |
| "High Complexity" | "Insertion" | 0.012716 | 2.1192e-116 | 2.1192e-116 |
| "High Complexity" | "Deletion" | 0.012852 | 2.3141e-118 | 3.0854e-118 |
| "Low Complexity" | "Insertion" | 0.027755 | 0.0 | 0.0 |
| "Low Complexity" | "Deletion" | 0.011657 | 2.3832e-318 | 4.7664e-318 |
4. Combined Plots¶
In [32]:
def create_combined_indel_metrics_plot(
indel_ont_stats: pl.DataFrame,
distribution_data: pl.DataFrame,
indel_size_plot_data: pl.DataFrame,
regression_results: Dict,
metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
complexities: Tuple[str, ...] = ("hc", "lc"),
figsize: Tuple[int, int] = (12, 16),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing indel performance metrics, size distributions,
and size-based performance analysis.
Args:
indel_ont_stats: DataFrame containing ONT indel statistics
distribution_data: DataFrame containing indel size distribution data
indel_size_plot_data: DataFrame containing size-based performance data
regression_results: Dictionary containing regression analysis results
metrics: Tuple of performance metrics to plot
complexities: Tuple of complexity regions to analyze
figsize: Figure dimensions (width, height)
dpi: Figure resolution
Returns:
Combined figure object containing all plots
Raises:
ValueError: If input data is invalid or missing
Exception: If there's an error creating the combined plot
"""
try:
# Input validation
if indel_ont_stats.height == 0 or distribution_data.height == 0:
raise ValueError("Input DataFrames cannot be empty")
# Create figure with GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(4, 2, height_ratios=[0.3, 0.3, 0.3, 1.3])
# Section A & B: Performance Metrics
performance_data = prepare_indel_performance_data(
ont_stats=indel_ont_stats, metrics=metrics, complexities=complexities
)
gs_perf = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0, :])
plot_indel_performance_metrics(plot_data_df=performance_data, gs=gs_perf)
# Add panel labels for A & B
for i, ax in enumerate(fig.axes[:2]):
label = chr(ord("A") + i)
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
)
# Section C-F: Size Distributions
gs_dist = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1:3, :])
plot_indel_size_distributions(df=distribution_data, gs=gs_dist)
# Add panel labels for C-F
for i, ax in enumerate(fig.axes[2:6]):
label = chr(ord("C") + i)
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
)
# Section G-I: Size Performance Analysis
gs_size = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[3, :])
plot_indel_size_performance(
plot_data=indel_size_plot_data,
regression_results=regression_results,
gs=gs_size,
)
# Add panel labels for G-I
for i, ax in enumerate(fig.axes[6:]):
label = chr(ord("G") + i)
ax.text(
-0.045,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
)
# Adjust layout
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined indel metrics plot: {str(e)}")
raise
combined_indel_fig = create_combined_indel_metrics_plot(
indel_ont_stats=indel_ont_stats,
distribution_data=distribution_data,
indel_size_plot_data=indel_size_plot_data,
regression_results=indel_size_perf_regression_results,
)
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
In [33]:
def create_snv_multiplexing_comparison_plot(
metrics_df: pl.DataFrame,
multiplexing_df: pl.DataFrame,
config: SNVAnalysisConfig,
figsize: Tuple[int, int] = (14, 8),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create violin plots comparing SNV performance metrics between multiplexed and
singleplexed samples.
Args:
metrics_df: DataFrame containing performance metrics
multiplexing_df: DataFrame containing multiplexing information
figsize: Figure size in inches
dpi: Figure resolution
metrics: Performance metrics to plot
complexities: Complexity levels to compare
gs: GridSpec for subplot placement. If None, creates standalone figure
Returns:
Optional[plt.Figure]: If gs is None, returns the figure.
If gs is provided, returns None (plots are added as subfigures).
Raises:
ValueError: If required columns are missing from input DataFrames
"""
try:
if gs is None:
fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = np.array(
[[plt.subplot(gs[i, j]) for j in range(3)] for i in range(2)]
)
fig.suptitle(
"SNV Performance Metrics vs Multiplexing",
)
merged_df = metrics_df.join(
multiplexing_df.select(["sample", "multiplexing"]),
left_on="sample_id",
right_on="sample",
how="inner",
)
for row, complexity in enumerate(config.complexities):
complexity_data = merged_df.filter(pl.col("complexity") == complexity)
for col, metric in enumerate(config.metrics_to_test):
ax = axes[row, col]
sns.violinplot(
x="multiplexing",
y=metric,
data=complexity_data,
ax=ax,
hue="multiplexing",
)
ax.set_xlabel("Multiplexing")
ax.set_ylabel(metric.capitalize().replace("_", "-"))
y_mins = []
y_maxs = []
for ax in axes[row, :]:
y_mins.append(ax.get_ylim()[0])
y_maxs.append(ax.get_ylim()[1])
y_min = min(y_mins)
y_max = max(y_maxs)
for ax in axes[row, :]:
ax.set_ylim(y_min, y_max)
axes[row, 0].annotate(
"High Complexity" if row == 0 else "Low Complexity",
xy=(-0.17, 0.5),
xycoords="axes fraction",
fontweight="bold",
ha="right",
va="center",
rotation=90,
)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error creating performance comparison plots: {str(e)}")
raise
performance_comparison_fig = create_snv_multiplexing_comparison_plot(
metrics_df=snv_rtg_metrics_dfs["ont"],
multiplexing_df=nanoplot_qc_metrics_df,
config=snv_config,
)
In [34]:
def create_indel_multiplexing_comparison_plot(
metrics_df: pl.DataFrame,
multiplexing_df: pl.DataFrame,
config: IndelAnalysisConfig,
figsize: Tuple[int, int] = (14, 8),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Create violin plots comparing indel performance metrics between multiplexed and
singleplexed samples.
Args:
metrics_df: DataFrame containing indel performance metrics
multiplexing_df: DataFrame containing multiplexing information
config: Analysis configuration
figsize: Figure size in inches
dpi: Figure resolution
gs: GridSpec for subplot placement. If None, creates standalone figure
Returns:
Optional[plt.Figure]: If gs is None, returns the figure.
If gs is provided, returns None (plots are added as subfigures).
Raises:
ValueError: If required columns are missing from input DataFrames
"""
try:
if gs is None:
fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
axes = np.array(
[[plt.subplot(gs[i, j]) for j in range(3)] for i in range(2)]
)
fig.suptitle(
"Indel Performance Metrics vs Multiplexing",
)
merged_df = metrics_df.join(
multiplexing_df.select(["sample", "multiplexing"]),
left_on="sample_id",
right_on="sample",
how="inner",
)
for row, complexity in enumerate(config.complexities):
complexity_data = merged_df.filter(pl.col("complexity") == complexity)
for col, metric in enumerate(config.metrics_to_test):
ax = axes[row, col]
sns.violinplot(
x="multiplexing",
y=metric,
data=complexity_data,
ax=ax,
hue="multiplexing",
)
ax.set_xlabel("Multiplexing")
ax.set_ylabel(metric.capitalize().replace("_", "-"))
y_mins = []
y_maxs = []
for ax in axes[row, :]:
y_mins.append(ax.get_ylim()[0])
y_maxs.append(ax.get_ylim()[1])
y_min = min(y_mins)
y_max = max(y_maxs)
for ax in axes[row, :]:
ax.set_ylim(y_min, y_max)
axes[row, 0].annotate(
"High Complexity" if row == 0 else "Low Complexity",
xy=(-0.17, 0.5),
xycoords="axes fraction",
fontweight="bold",
ha="right",
va="center",
rotation=90,
)
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error creating indel performance comparison plots: {str(e)}")
raise
indel_performance_comparison_fig = create_indel_multiplexing_comparison_plot(
metrics_df=indel_rtg_metrics_dfs["ont"],
multiplexing_df=nanoplot_qc_metrics_df,
config=indel_config,
)
In [35]:
def create_combined_multiplexing_comparison_plot(
snv_metrics_df: pl.DataFrame,
indel_metrics_df: pl.DataFrame,
multiplexing_df: pl.DataFrame,
snv_config: SNVAnalysisConfig,
indel_config: IndelAnalysisConfig,
figsize: Tuple[int, int] = (14, 16),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing both SNV and indel performance metrics
comparisons between multiplexed and singleplexed samples.
Args:
snv_metrics_df: DataFrame containing SNV performance metrics
indel_metrics_df: DataFrame containing indel performance metrics
multiplexing_df: DataFrame containing multiplexing information
snv_config: Configuration for SNV analysis
indel_config: Configuration for indel analysis
figsize: Figure dimensions (width, height)
dpi: Figure resolution
Returns:
Combined figure object containing all plots
Raises:
ValueError: If input data is invalid or missing
Exception: If there's an error creating the combined plot
"""
try:
# Input validation
if snv_metrics_df.height == 0 or indel_metrics_df.height == 0:
raise ValueError("Input DataFrames cannot be empty")
# Create figure with GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(
3, 1, height_ratios=[1, 0.03, 1]
) # Add middle spacing row
# SNV section
gs_snv = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0])
create_snv_multiplexing_comparison_plot(
metrics_df=snv_metrics_df,
multiplexing_df=multiplexing_df,
config=snv_config,
gs=gs_snv,
)
# Add SNV section title
fig.text(
0.5,
0.99,
"SNV Performance Metrics vs Multiplexing",
ha="center",
fontsize=12,
fontweight="bold",
)
# Indel section
gs_indel = gridspec.GridSpecFromSubplotSpec(
2, 3, subplot_spec=gs[2]
) # Move Indel to row 3 for padding
create_indel_multiplexing_comparison_plot(
metrics_df=indel_metrics_df,
multiplexing_df=multiplexing_df,
config=indel_config,
gs=gs_indel,
)
# Add Indel section title
fig.text(
0.5,
0.48,
"Indel Performance Metrics vs Multiplexing",
ha="center",
fontsize=12,
fontweight="bold",
)
# Add panel labels (one per row)
for i, ax in enumerate(
fig.axes[::3]
): # Step by 3 to label the first plot in each row
label = chr(ord("A") + i)
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
)
# Adjust layout
fig.suptitle("")
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined variant comparison plot: {str(e)}")
raise
combined_variant_fig = create_combined_multiplexing_comparison_plot(
snv_metrics_df=snv_rtg_metrics_dfs["ont"],
indel_metrics_df=indel_rtg_metrics_dfs["ont"],
multiplexing_df=nanoplot_qc_metrics_df,
snv_config=snv_config,
indel_config=indel_config,
)
In [36]:
@dataclass
class MetricStats:
"""Class for storing statistical metrics.
Attributes:
mean (float): The mean value of the metric
std (float): The standard deviation of the metric
median (float): The median value of the metric
"""
mean: float
std: float
median: float
@dataclass
class VariantStats:
"""Class for storing variant statistics.
Attributes:
precision (MetricStats): Precision metrics including mean, std, and median
sensitivity (MetricStats): Sensitivity metrics including mean, std, and median
f_measure (MetricStats): F-measure metrics including mean, std, and median
multiplexing (str): Type of multiplexing (singleplex or multiplex)
variant_type (str): Type of variant (SNV or Indel)
complexity (str): Complexity level (hc or lc)
"""
precision: MetricStats
sensitivity: MetricStats
f_measure: MetricStats
multiplexing: str
variant_type: str
complexity: str
def create_variant_multiplexing_stats(
snv_df: pl.DataFrame, indel_df: pl.DataFrame, multiplexing_df: pl.DataFrame
) -> List[VariantStats]:
"""Creates variant multiplexing statistics by combining SNV, Indel, and multiplexing data.
Args:
snv_df (pl.DataFrame): DataFrame containing SNV metrics
indel_df (pl.DataFrame): DataFrame containing Indel metrics
multiplexing_df (pl.DataFrame): DataFrame containing multiplexing information
Returns:
List[VariantStats]: List of VariantStats objects containing calculated statistics
for different combinations of variant types and complexity levels
Raises:
ValueError: If required columns are missing in the input DataFrames
"""
def calculate_stats(
df: pl.DataFrame, variant_type: str, complexity: str
) -> List[VariantStats]:
"""Calculates statistics for a specific variant type and complexity level.
Args:
df (pl.DataFrame): DataFrame containing variant metrics
variant_type (str): Type of variant (SNV or Indel)
complexity (str): Complexity level (hc or lc)
Returns:
List[VariantStats]: List of VariantStats objects for the specified variant type
and complexity level
"""
merged_df = df.filter(pl.col("complexity") == complexity).join(
multiplexing_df.select(["sample", "multiplexing"]),
left_on="sample_id",
right_on="sample",
how="inner",
)
stats_df = merged_df.group_by("multiplexing").agg(
[
pl.col("precision").mean().alias("Precision_mean"),
pl.col("precision").std().alias("Precision_std"),
pl.col("precision").median().alias("Precision_median"),
pl.col("sensitivity").mean().alias("Sensitivity_mean"),
pl.col("sensitivity").std().alias("Sensitivity_std"),
pl.col("sensitivity").median().alias("Sensitivity_median"),
pl.col("f_measure").mean().alias("F-measure_mean"),
pl.col("f_measure").std().alias("F-measure_std"),
pl.col("f_measure").median().alias("F-measure_median"),
]
)
return [
VariantStats(
precision=MetricStats(
mean=row["Precision_mean"],
std=row["Precision_std"],
median=row["Precision_median"],
),
sensitivity=MetricStats(
mean=row["Sensitivity_mean"],
std=row["Sensitivity_std"],
median=row["Sensitivity_median"],
),
f_measure=MetricStats(
mean=row["F-measure_mean"],
std=row["F-measure_std"],
median=row["F-measure_median"],
),
multiplexing=row["multiplexing"],
variant_type=variant_type,
complexity=complexity,
)
for row in stats_df.to_dicts()
]
stats_list = []
for variant_type, df in [("SNV", snv_df), ("Indel", indel_df)]:
for complexity in ["hc", "lc"]:
stats_list.extend(calculate_stats(df, variant_type, complexity))
return stats_list
def format_number_decimals(num: float) -> str:
"""Formats a number to three decimal places.
Args:
num (float): Number to format
Returns:
str: Formatted string with three decimal places
"""
return f"{num:.3f}"
def summarize_variant_stats(stats_list: List[VariantStats]) -> None:
"""Summarizes and prints variant statistics comparing singleplex and multiplex results.
Args:
stats_list (List[VariantStats]): List of VariantStats objects containing
calculated statistics
Prints:
Formatted summary of statistics including precision, sensitivity, and F-measure
for both singleplex and multiplex variants, along with percentage increases
"""
for variant_type in ["SNV", "Indel"]:
for complexity in ["hc", "lc"]:
print(f"\n{variant_type} Statistics ({complexity.upper()}):")
print("=" * 40)
variant_stats = [
stat
for stat in stats_list
if stat.variant_type == variant_type and stat.complexity == complexity
]
singleplex_stats = next(
stat for stat in variant_stats if stat.multiplexing == "singleplex"
)
multiplex_stats = next(
stat for stat in variant_stats if stat.multiplexing == "multiplex"
)
for metric_name, metric_pair in [
("Precision", (singleplex_stats.precision, multiplex_stats.precision)),
(
"Sensitivity",
(singleplex_stats.sensitivity, multiplex_stats.sensitivity),
),
("F-measure", (singleplex_stats.f_measure, multiplex_stats.f_measure)),
]:
singleplex_metric, multiplex_metric = metric_pair
print(f"\n{metric_name}:")
for stat_name in ["mean", "std", "median"]:
singleplex_val = getattr(singleplex_metric, stat_name)
multiplex_val = getattr(multiplex_metric, stat_name)
print(
f" {stat_name.capitalize():6s}: "
f"Singleplex: {format_number_decimals(singleplex_val)}, "
f"Multiplex: {format_number_decimals(multiplex_val)}"
)
increase = _calculate_percentage_increase(
singleplex_metric.mean, multiplex_metric.mean
)
print(
f" Mean Percentage Increase (Singleplex vs Multiplex): "
f"{increase:6.2f}%"
)
variant_multiplexing_stats = create_variant_multiplexing_stats(
snv_rtg_metrics_dfs["ont"], indel_rtg_metrics_dfs["ont"], nanoplot_qc_metrics_df
)
summarize_variant_stats(variant_multiplexing_stats)
SNV Statistics (HC): ======================================== Precision: Mean : Singleplex: 0.960, Multiplex: 0.944 Std : Singleplex: 0.003, Multiplex: 0.008 Median: Singleplex: 0.961, Multiplex: 0.947 Mean Percentage Increase (Singleplex vs Multiplex): 1.68% Sensitivity: Mean : Singleplex: 0.970, Multiplex: 0.936 Std : Singleplex: 0.004, Multiplex: 0.018 Median: Singleplex: 0.970, Multiplex: 0.941 Mean Percentage Increase (Singleplex vs Multiplex): 3.67% F-measure: Mean : Singleplex: 0.965, Multiplex: 0.940 Std : Singleplex: 0.003, Multiplex: 0.013 Median: Singleplex: 0.965, Multiplex: 0.945 Mean Percentage Increase (Singleplex vs Multiplex): 2.67% SNV Statistics (LC): ======================================== Precision: Mean : Singleplex: 0.788, Multiplex: 0.765 Std : Singleplex: 0.005, Multiplex: 0.009 Median: Singleplex: 0.789, Multiplex: 0.767 Mean Percentage Increase (Singleplex vs Multiplex): 2.92% Sensitivity: Mean : Singleplex: 0.747, Multiplex: 0.717 Std : Singleplex: 0.005, Multiplex: 0.015 Median: Singleplex: 0.747, Multiplex: 0.722 Mean Percentage Increase (Singleplex vs Multiplex): 4.27% F-measure: Mean : Singleplex: 0.767, Multiplex: 0.740 Std : Singleplex: 0.005, Multiplex: 0.012 Median: Singleplex: 0.767, Multiplex: 0.744 Mean Percentage Increase (Singleplex vs Multiplex): 3.62% Indel Statistics (HC): ======================================== Precision: Mean : Singleplex: 0.836, Multiplex: 0.720 Std : Singleplex: 0.028, Multiplex: 0.035 Median: Singleplex: 0.829, Multiplex: 0.734 Mean Percentage Increase (Singleplex vs Multiplex): 16.13% Sensitivity: Mean : Singleplex: 0.930, Multiplex: 0.873 Std : Singleplex: 0.009, Multiplex: 0.030 Median: Singleplex: 0.928, Multiplex: 0.882 Mean Percentage Increase (Singleplex vs Multiplex): 6.52% F-measure: Mean : Singleplex: 0.880, Multiplex: 0.789 Std : Singleplex: 0.019, Multiplex: 0.033 Median: Singleplex: 0.876, Multiplex: 0.801 Mean Percentage Increase (Singleplex vs Multiplex): 11.57% Indel Statistics (LC): ======================================== Precision: Mean : Singleplex: 0.458, Multiplex: 0.310 Std : Singleplex: 0.074, Multiplex: 0.019 Median: Singleplex: 0.425, Multiplex: 0.313 Mean Percentage Increase (Singleplex vs Multiplex): 47.53% Sensitivity: Mean : Singleplex: 0.490, Multiplex: 0.424 Std : Singleplex: 0.019, Multiplex: 0.024 Median: Singleplex: 0.483, Multiplex: 0.429 Mean Percentage Increase (Singleplex vs Multiplex): 15.66% F-measure: Mean : Singleplex: 0.472, Multiplex: 0.358 Std : Singleplex: 0.048, Multiplex: 0.021 Median: Singleplex: 0.452, Multiplex: 0.362 Mean Percentage Increase (Singleplex vs Multiplex): 31.70%
2. Impact of sequencing depth on variant calling¶
In [37]:
@dataclass
class PerformanceCorrelation:
"""Data class for storing performance correlation results."""
correlation: float
p_value: float
fit_params: Tuple[float, float, float]
confidence_intervals: np.ndarray
def asymptotic_func(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
"""Calculate asymptotic function.
Args:
x: Input array
a: First parameter
b: Second parameter
c: Third parameter
Returns:
Calculated asymptotic values
"""
return a - b * np.exp(-c * x)
def calculate_correlation_stats(x: np.ndarray, y: np.ndarray) -> PerformanceCorrelation:
"""Calculate correlation statistics between two variables.
Args:
x: Independent variable array
y: Dependent variable array
Returns:
PerformanceCorrelation object containing correlation statistics
Raises:
ValueError: If curve fitting fails
"""
try:
# Calculate Pearson correlation
correlation, p_value = stats.pearsonr(x, y)
# Fit asymptotic curve
popt, pcov = curve_fit(
asymptotic_func, x, y, p0=[1, 0.1, 0.1], bounds=([0, 0, 0], [2, 1, 1])
)
# Calculate confidence intervals
perr = np.sqrt(np.diag(pcov))
n = len(x)
dof = max(0, n - len(popt))
t = stats.t.ppf(0.975, dof)
y_err = np.sqrt(np.sum((y - asymptotic_func(x, *popt)) ** 2) / dof)
x_range = np.linspace(x.min(), x.max(), 100)
ci = (
t
* y_err
* np.sqrt(
1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
)
)
return PerformanceCorrelation(
correlation=correlation,
p_value=p_value,
fit_params=tuple(popt),
confidence_intervals=ci,
)
except Exception as e:
raise ValueError(f"Error calculating correlation statistics: {str(e)}")
def plot_depth_vs_performance(
depth_df: pl.DataFrame,
snv_metrics_df: pl.DataFrame,
indel_metrics_df: pl.DataFrame,
metrics_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 16),
dpi: int = 300,
) -> plt.Figure:
"""Create plots comparing variant calling performance metrics against sequencing depth.
Args:
depth_df: DataFrame containing whole genome depth statistics
snv_metrics_df: DataFrame containing SNV calling metrics
indel_metrics_df: DataFrame containing indel calling metrics
metrics_df: DataFrame containing sample metadata
figsize: Figure size in inches
dpi: Figure resolution
Returns:
plt.Figure: Figure object containing the plots
Raises:
ValueError: If required columns are missing
"""
try:
metrics = ["precision", "sensitivity", "f_measure"]
variant_types = ["SNV", "Indel"]
complexities = ["hc", "lc"]
fig = plt.figure(figsize=figsize, dpi=dpi)
# Add height_ratios for padding between sections
gs = gridspec.GridSpec(5, 1, height_ratios=[1, 1, 0.05, 1, 1])
row_positions = {}
# Add section titles
fig.text(
0.5,
0.99,
"SNV Performance Metrics vs Whole Genome Mean Depth",
ha="center",
fontsize=12,
fontweight="bold",
)
fig.text(
0.5,
0.468,
"Indel Performance Metrics vs Whole Genome Mean Depth",
ha="center",
fontsize=12,
fontweight="bold",
)
for i, complexity in enumerate(complexities):
y_limits = {metric: (float("inf"), -float("inf")) for metric in metrics}
for j, (variant_type, variant_metrics_df) in enumerate(
zip(variant_types, [snv_metrics_df, indel_metrics_df])
):
# Adjust row index to account for padding
main_row_index = j * 3 + i if j == 1 else i
inner_gs = gridspec.GridSpecFromSubplotSpec(
1, 3, subplot_spec=gs[main_row_index]
)
data = (
variant_metrics_df.filter(pl.col("complexity") == complexity)
.join(
metrics_df.select(["sample", "multiplexing"]),
left_on="sample_id",
right_on="sample",
)
.join(
depth_df.select(["sample", "mean_depth"]),
left_on="sample_id",
right_on="sample",
)
)
for k, metric in enumerate(metrics):
ax = plt.subplot(inner_gs[k])
if k == 0:
row_positions[main_row_index] = ax.get_position()
ax.annotate(
chr(ord("A") + (j * 2 + i)),
xy=(-0.1, 1.05),
xycoords="axes fraction",
fontsize=12,
fontweight="bold",
)
# Add complexity label
complexity_label = (
"High Complexity"
if complexity == "hc"
else "Low Complexity"
)
ax.annotate(
complexity_label,
xy=(-0.18, 0.5),
xycoords="axes fraction",
fontweight="bold",
ha="right",
va="center",
rotation=90,
)
multiplexing_values = sorted(
data["multiplexing"].unique().to_list()
)
colors = sns.color_palette("colorblind", len(multiplexing_values))
color_mapping = dict(zip(multiplexing_values, colors))
for multiplex in multiplexing_values:
subset = data.filter(pl.col("multiplexing") == multiplex)
ax.scatter(
subset["mean_depth"].to_numpy(),
subset[metric].to_numpy(),
color=color_mapping[multiplex],
label=(
str(multiplex) if i == 0 and j == 0 and k == 0 else ""
),
s=100,
)
x = data["mean_depth"].to_numpy()
y = data[metric].to_numpy()
y_limits[metric] = (
min(y_limits[metric][0], y.min()),
max(y_limits[metric][1], y.max()),
)
corr_stats = calculate_correlation_stats(x, y)
x_range = np.linspace(x.min(), x.max(), 100)
y_fit = asymptotic_func(x_range, *corr_stats.fit_params)
ax.plot(
x_range,
y_fit,
color="gray",
linestyle="-",
linewidth=2,
label=(
"Line of Best Fit" if i == 0 and j == 0 and k == 0 else ""
),
)
ax.fill_between(
x_range,
y_fit - corr_stats.confidence_intervals,
y_fit + corr_stats.confidence_intervals,
color="gray",
alpha=0.2,
label=(
"95% Confidence Interval"
if i == 0 and j == 0 and k == 0
else ""
),
)
ax.set_title(
f"r={corr_stats.correlation:.2f}, p={corr_stats.p_value:.2e}"
)
ax.set_xlabel("Whole Genome Mean Depth")
ax.set_ylabel(metric.capitalize().replace("_", "-"))
if i == 0 and j == 0 and k == 0:
handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(
handles=handles, loc="lower right", title="Multiplexing"
)
legend.get_title().set_weight("bold")
plt.tight_layout()
return fig
except Exception as e:
logger.error(f"Error creating depth vs performance plots: {str(e)}")
raise
performance_depth_fig = plot_depth_vs_performance(
depth_df=total_depth_df,
snv_metrics_df=snv_rtg_metrics_dfs["ont"],
indel_metrics_df=indel_rtg_metrics_dfs["ont"],
metrics_df=nanoplot_qc_metrics_df,
)
In [38]:
@dataclass
class AncovaResult:
"""Data class for storing ANCOVA analysis results."""
depth_effect: float
depth_ci_low: float
depth_ci_high: float
depth_pvalue: float
multiplex_effect: float
multiplex_ci_low: float
multiplex_ci_high: float
multiplex_pvalue: float
r_squared: float
adj_r_squared: float
def _prepare_ancova_data(
depth_df: pl.DataFrame, metrics_df: pl.DataFrame, np_metrics_df: pl.DataFrame
) -> pl.DataFrame:
"""Prepare data for ANCOVA analysis by merging relevant dataframes.
Args:
depth_df: DataFrame containing depth information
metrics_df: DataFrame containing metrics information
np_metrics_df: DataFrame containing nanoplot metrics information
Returns:
pl.DataFrame: Combined and processed DataFrame for ANCOVA analysis
"""
try:
depth_data = depth_df.group_by("sample").agg(
pl.col("mean_depth").mean().alias("wg_mean_depth")
)
# First join metrics with nanoplot data
data = metrics_df.join(
np_metrics_df.select(["sample", "multiplexing"]),
left_on="sample_id",
right_on="sample",
)
# Then join with depth data using sample_id
data = data.join(
depth_data,
left_on="sample_id",
right_on="sample",
)
data = data.with_columns(
pl.when(pl.col("multiplexing") == "multiplex")
.then(1)
.otherwise(0)
.alias("multiplexing_dummy")
)
return data
except Exception as e:
logger.error(f"Error preparing ANCOVA data: {str(e)}")
raise
def _get_ancova_statistical_significance(
pvalue: float, thresholds: Dict[str, float] = {"***": 0.001, "**": 0.01, "*": 0.05}
) -> str:
"""Determine statistical significance notation based on p-value.
Args:
pvalue: P-value from statistical test
thresholds: Dictionary of significance thresholds and their corresponding symbols,
default is standard thresholds (***: p<0.001, **: p<0.01, *: p<0.05)
Returns:
str: Significance stars ("***", "**", "*") or empty string if not significant
"""
try:
if not isinstance(pvalue, (int, float)):
raise ValueError(f"P-value must be numeric, got {type(pvalue)}")
if pvalue < 0 or pvalue > 1:
raise ValueError(f"P-value must be between 0 and 1, got {pvalue}")
# Sort thresholds by value in descending order to check most stringent first
sorted_thresholds = dict(sorted(thresholds.items(), key=lambda x: x[1]))
for symbol, threshold in sorted_thresholds.items():
if pvalue < threshold:
return symbol
return ""
except Exception as e:
logger.error(f"Error determining statistical significance: {str(e)}")
raise
def create_forest_plot_ancova(
snv_results: Dict[str, AncovaResult],
indel_results: Dict[str, AncovaResult],
figsize: Tuple[int, int] = (12, 7),
dpi: int = 300,
) -> plt.Figure:
"""Create forest plot for ANCOVA analysis results with colored effects and significance stars.
Args:
snv_results: Dictionary of SNV ANCOVA results
indel_results: Dictionary of INDEL ANCOVA results
figsize: Figure size (width, height)
dpi: Figure resolution
Returns:
plt.Figure: Generated matplotlib figure
"""
try:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, dpi=dpi)
palette = sns.color_palette("colorblind")
def _add_results_to_plot(
ax: plt.Axes,
results: Dict[str, AncovaResult],
title: str,
show_legend: bool = True,
) -> None:
# Add panel label
ax.text(
-0.1,
1.1,
"A" if title == "SNVs" else "B",
transform=ax.transAxes,
fontsize=14,
fontweight="bold",
)
metrics = ["Precision", "Sensitivity", "F-measure"]
complexities = ["hc", "lc"]
# Prepare data structure
plot_data = []
labels = []
for metric in metrics:
for complexity in complexities:
key = f"{metric}_{complexity}"
plot_data.append((metric, complexity, results[key]))
labels.append(f"{metric} ({complexity.upper()})")
y_positions = np.arange(len(plot_data))
# Plot multiplexing and depth effects
for i, (_, _, result) in enumerate(plot_data):
# Multiplexing effect
ax.plot(
[result.multiplex_effect],
[y_positions[i]],
"o",
color=palette[0],
label="Multiplexing" if i == 0 else "",
)
ax.plot(
[result.multiplex_ci_low, result.multiplex_ci_high],
[y_positions[i], y_positions[i]],
"-",
color=palette[0],
)
stars = _get_ancova_statistical_significance(result.multiplex_pvalue)
if stars:
ax.text(
result.multiplex_effect,
y_positions[i] + 0.1,
stars,
ha="left",
va="center",
color=palette[0],
)
# Depth effect
ax.plot(
[result.depth_effect],
[y_positions[i]],
"o",
color=palette[1],
label="Depth" if i == 0 else "",
)
ax.plot(
[result.depth_ci_low, result.depth_ci_high],
[y_positions[i], y_positions[i]],
"-",
color=palette[1],
)
stars = _get_ancova_statistical_significance(result.depth_pvalue)
if stars:
ax.text(
result.depth_effect,
y_positions[i] + 0.1,
stars,
ha="left",
va="center",
color=palette[1],
)
# Customize plot
ax.axvline(x=0, color="gray", linestyle="--", alpha=0.5)
ax.set_yticks(y_positions)
ax.set_yticklabels(labels)
ax.set_title(f"{title}", fontweight="bold")
ax.set_xlabel("Effect Size (with 95% Confidence Intervals)")
if show_legend:
legend = ax.legend(loc="lower right", title="Metrics")
legend.get_title().set_fontweight("bold")
elif ax.get_legend():
ax.get_legend().remove()
# Create plots for SNVs and INDELs
_add_results_to_plot(ax1, snv_results, "SNVs", show_legend=True)
_add_results_to_plot(ax2, indel_results, "Indels", show_legend=False)
plt.tight_layout()
return fig
except Exception as e:
logger.error(f"Error creating forest plot: {str(e)}")
raise
def _perform_ancova(data: pl.DataFrame, metric: str, complexity: str) -> AncovaResult:
"""Perform Analysis of Covariance (ANCOVA) for a specific metric and complexity level.
Args:
data: Polars DataFrame containing the analysis data
metric: The metric to analyze ('Precision', 'Sensitivity', or 'F-measure')
complexity: The complexity level to analyze ('hc' or 'lc')
Returns:
AncovaResult: Dataclass containing the ANCOVA analysis results including:
- Effects, confidence intervals, and p-values for depth and multiplexing
- R-squared and adjusted R-squared values
Raises:
Exception: If there's an error during the ANCOVA analysis
"""
try:
# Create a mapping for column names
metric_mapping = {
"Precision": "precision",
"Sensitivity": "sensitivity",
"F-measure": "f_measure",
}
subset = data.filter(pl.col("complexity") == complexity)
subset = subset.with_columns(pl.col("wg_mean_depth").log().alias("log_depth"))
X = sm.add_constant(
subset.select(["log_depth", "multiplexing_dummy"]).to_numpy()
)
# Use the mapped column name
y = subset.select(metric_mapping[metric]).to_numpy().flatten()
model = sm.OLS(y, X).fit()
conf_int = model.conf_int()
return AncovaResult(
depth_effect=model.params[1],
depth_ci_low=conf_int[1, 0],
depth_ci_high=conf_int[1, 1],
depth_pvalue=model.pvalues[1],
multiplex_effect=model.params[2],
multiplex_ci_low=conf_int[2, 0],
multiplex_ci_high=conf_int[2, 1],
multiplex_pvalue=model.pvalues[2],
r_squared=model.rsquared,
adj_r_squared=model.rsquared_adj,
)
except Exception as e:
logger.error(f"Error performing ANCOVA analysis: {str(e)}")
raise
def print_ancova_results(
snv_results: Dict[str, AncovaResult], indel_results: Dict[str, AncovaResult]
) -> None:
"""Print formatted ANCOVA analysis results for SNVs and INDELs.
Args:
snv_results: Dictionary mapping metric_complexity to AncovaResult for SNVs
indel_results: Dictionary mapping metric_complexity to AncovaResult for INDELs
Output format:
- Results are grouped by variant type (SNV/INDEL)
- Each row shows metric, complexity, depth effect, and multiplexing effect
- Effects are displayed with confidence intervals and significance stars
- Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant
"""
def _format_effect(
effect: float, ci_low: float, ci_high: float, pvalue: float
) -> str:
stars = _get_ancova_statistical_significance(pvalue)
significance = f" {stars}" if stars else " (ns)"
return (
f"{effect:.3f} [{ci_low:.3f}, {ci_high:.3f}]{significance} "
f"(p={pvalue:.4f})"
)
def _print_variant_results(
results: Dict[str, AncovaResult], variant_type: str
) -> None:
print(f"\n{variant_type} Results:")
print("=" * 80)
print(
f"{'Metric':<15} {'Complexity':<10} {'Depth Effect':<35} {'Multiplexing Effect':<35}"
)
print("-" * 80)
metrics = ["Precision", "Sensitivity", "F-measure"]
complexities = ["hc", "lc"]
for metric in metrics:
for complexity in complexities:
key = f"{metric}_{complexity}"
result = results[key]
depth_effect = _format_effect(
result.depth_effect,
result.depth_ci_low,
result.depth_ci_high,
result.depth_pvalue,
)
multiplex_effect = _format_effect(
result.multiplex_effect,
result.multiplex_ci_low,
result.multiplex_ci_high,
result.multiplex_pvalue,
)
print(
f"{metric:<15} {complexity.upper():<10} {depth_effect:<35} "
f"{multiplex_effect:<35}"
)
print("-" * 80)
print("\nANCOVA Analysis Results")
print("=====================")
print("Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")
_print_variant_results(snv_results, "SNV")
_print_variant_results(indel_results, "INDEL")
def run_ancova_analysis(
depth_df: pl.DataFrame,
snv_ont_metrics_df: pl.DataFrame,
indel_ont_metrics_df: pl.DataFrame,
np_metrics_df: pl.DataFrame,
) -> Tuple[Dict[str, AncovaResult], Dict[str, AncovaResult]]:
"""Execute ANCOVA analysis for SNV and Indel metrics.
Args:
depth_df: Polars DataFrame containing depth information
snv_ont_metrics_df: Polars DataFrame containing SNV metrics
indel_ont_metrics_df: Polars DataFrame containing INDEL metrics
np_metrics_df: Polars DataFrame containing nanoplot metrics
Returns:
Tuple[Dict[str, AncovaResult], Dict[str, AncovaResult]]: Two dictionaries containing
ANCOVA results for SNVs and INDELs respectively, with keys formatted as
'metric_complexity' (e.g., 'Precision_hc')
Raises:
Exception: If there's an error during any stage of the analysis
"""
try:
logger.info("Starting ANCOVA analysis")
snv_data = _prepare_ancova_data(depth_df, snv_ont_metrics_df, np_metrics_df)
indel_data = _prepare_ancova_data(depth_df, indel_ont_metrics_df, np_metrics_df)
metrics = ["Precision", "Sensitivity", "F-measure"]
complexities = ["hc", "lc"]
snv_results: Dict[str, AncovaResult] = {}
indel_results: Dict[str, AncovaResult] = {}
for metric in metrics:
for complexity in complexities:
key = f"{metric}_{complexity}"
logger.info(f"Analyzing {key}")
snv_results[key] = _perform_ancova(snv_data, metric, complexity)
indel_results[key] = _perform_ancova(indel_data, metric, complexity)
logger.info("ANCOVA analysis completed successfully")
return snv_results, indel_results
except Exception as e:
logger.error(f"Error in ANCOVA analysis: {str(e)}")
raise
snv_ancova_results, indel_ancova_results = run_ancova_analysis(
depth_df=total_depth_df,
snv_ont_metrics_df=snv_rtg_metrics_dfs["ont"],
indel_ont_metrics_df=indel_rtg_metrics_dfs["ont"],
np_metrics_df=nanoplot_qc_metrics_df,
)
print_ancova_results(snv_ancova_results, indel_ancova_results)
forest_plot_ancova = create_forest_plot_ancova(
snv_results=snv_ancova_results,
indel_results=indel_ancova_results,
)
__main__ - INFO - Starting ANCOVA analysis
__main__ - INFO - Analyzing Precision_hc
__main__ - INFO - Analyzing Precision_lc
__main__ - INFO - Analyzing Sensitivity_hc
__main__ - INFO - Analyzing Sensitivity_lc
__main__ - INFO - Analyzing F-measure_hc
__main__ - INFO - Analyzing F-measure_lc
__main__ - INFO - ANCOVA analysis completed successfully
ANCOVA Analysis Results ===================== Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant SNV Results: ================================================================================ Metric Complexity Depth Effect Multiplexing Effect -------------------------------------------------------------------------------- Precision HC 0.015 [0.001, 0.029] * (p=0.0341) -0.005 [-0.017, 0.006] (ns) (p=0.3128) Precision LC 0.022 [0.007, 0.038] ** (p=0.0088) -0.007 [-0.019, 0.005] (ns) (p=0.2279) -------------------------------------------------------------------------------- Sensitivity HC 0.037 [0.009, 0.065] * (p=0.0133) -0.009 [-0.031, 0.013] (ns) (p=0.3915) Sensitivity LC 0.034 [0.011, 0.056] ** (p=0.0072) -0.008 [-0.026, 0.010] (ns) (p=0.3735) -------------------------------------------------------------------------------- F-measure HC 0.026 [0.006, 0.046] * (p=0.0143) -0.007 [-0.023, 0.009] (ns) (p=0.3412) F-measure LC 0.028 [0.011, 0.046] ** (p=0.0038) -0.007 [-0.021, 0.006] (ns) (p=0.2608) -------------------------------------------------------------------------------- INDEL Results: ================================================================================ Metric Complexity Depth Effect Multiplexing Effect -------------------------------------------------------------------------------- Precision HC 0.131 [0.094, 0.168] *** (p=0.0000) -0.027 [-0.056, 0.003] (ns) (p=0.0717) Precision LC 0.252 [0.201, 0.304] *** (p=0.0000) 0.025 [-0.016, 0.066] (ns) (p=0.2099) -------------------------------------------------------------------------------- Sensitivity HC 0.065 [0.023, 0.108] ** (p=0.0063) -0.012 [-0.046, 0.022] (ns) (p=0.4423) Sensitivity LC 0.087 [0.063, 0.111] *** (p=0.0000) -0.007 [-0.026, 0.013] (ns) (p=0.4584) -------------------------------------------------------------------------------- F-measure HC 0.102 [0.063, 0.141] *** (p=0.0001) -0.021 [-0.052, 0.010] (ns) (p=0.1580) F-measure LC 0.174 [0.153, 0.194] *** (p=0.0000) 0.005 [-0.011, 0.022] (ns) (p=0.5005) --------------------------------------------------------------------------------
3. Impact of read length on variant calling¶
In [39]:
@dataclass
class PerformanceCorrelation:
"""Data class for storing performance correlation results."""
correlation: float
p_value: float
fit_params: Tuple[float, float]
confidence_intervals: np.ndarray
def plot_read_length_vs_performance(
snv_metrics_df: pl.DataFrame,
indel_metrics_df: pl.DataFrame,
metrics_df: pl.DataFrame,
figsize: Tuple[int, int] = (14, 16),
dpi: int = 300,
) -> plt.Figure:
"""Create plots comparing variant calling performance metrics against read length.
Args:
snv_metrics_df: DataFrame containing SNV calling metrics
indel_metrics_df: DataFrame containing indel calling metrics
metrics_df: DataFrame containing sample metadata with mean read lengths
figsize: Figure size in inches
dpi: Figure resolution
Returns:
plt.Figure: Figure object containing the plots
"""
try:
metrics = ["precision", "sensitivity", "f_measure"]
variant_types = ["SNV", "Indel"]
complexities = ["hc", "lc"]
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = gridspec.GridSpec(5, 1, height_ratios=[1, 1, 0.05, 1, 1])
# Add section titles
fig.text(
0.5,
0.99,
"SNV Performance Metrics vs Mean Read Length",
ha="center",
fontsize=12,
fontweight="bold",
)
fig.text(
0.5,
0.468,
"Indel Performance Metrics vs Mean Read Length",
ha="center",
fontsize=12,
fontweight="bold",
)
for i, complexity in enumerate(complexities):
for j, (variant_type, variant_metrics_df) in enumerate(
zip(variant_types, [snv_metrics_df, indel_metrics_df])
):
main_row_index = j * 3 + i if j == 1 else i
inner_gs = gridspec.GridSpecFromSubplotSpec(
1, 3, subplot_spec=gs[main_row_index]
)
# Filter and join data
data = variant_metrics_df.filter(
pl.col("complexity") == complexity
).join(
metrics_df.select(["sample", "multiplexing", "mean_read_length"]),
left_on="sample_id",
right_on="sample",
)
if len(data) == 0:
logger.warning(
f"No data after joining for {variant_type} ({complexity})"
)
continue
for k, metric in enumerate(metrics):
ax = plt.subplot(inner_gs[k])
if k == 0:
ax.annotate(
chr(ord("A") + (j * 2 + i)),
xy=(-0.1, 1.05),
xycoords="axes fraction",
fontsize=12,
fontweight="bold",
)
complexity_label = (
"High Complexity"
if complexity == "hc"
else "Low Complexity"
)
ax.annotate(
complexity_label,
xy=(-0.18, 0.5),
xycoords="axes fraction",
fontweight="bold",
ha="right",
va="center",
rotation=90,
)
multiplexing_values = sorted(
data["multiplexing"].unique().to_list()
)
colors = sns.color_palette("colorblind", len(multiplexing_values))
color_mapping = dict(zip(multiplexing_values, colors))
for multiplex in multiplexing_values:
subset = data.filter(pl.col("multiplexing") == multiplex)
x = subset["mean_read_length"].to_numpy()
y = subset[metric].to_numpy()
ax.scatter(
x,
y,
color=color_mapping[multiplex],
label=(
str(multiplex) if i == 0 and j == 0 and k == 0 else ""
),
s=100,
)
x = data["mean_read_length"].to_numpy()
y = data[metric].to_numpy()
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
x_range = np.linspace(x.min(), x.max(), 100)
y_fit = slope * x_range + intercept
ax.plot(
x_range,
y_fit,
color="gray",
linestyle="-",
linewidth=2,
label=(
"Line of Best Fit" if i == 0 and j == 0 and k == 0 else ""
),
)
n = len(x)
y_err = np.sqrt(
np.sum((y - (slope * x + intercept)) ** 2) / (n - 2)
)
ci = (
stats.t.ppf(0.975, n - 2)
* y_err
* np.sqrt(
1 / n
+ (x_range - np.mean(x)) ** 2
/ np.sum((x - np.mean(x)) ** 2)
)
)
ax.fill_between(
x_range,
y_fit - ci,
y_fit + ci,
color="gray",
alpha=0.2,
label=(
"95% Confidence Interval"
if i == 0 and j == 0 and k == 0
else ""
),
)
ax.set_title(f"r={r_value:.2f}, p={p_value:.2e}")
ax.set_xlabel("Mean Read Length")
ax.set_ylabel(metric.capitalize())
if i == 0 and j == 0 and k == 0:
handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(
handles=handles,
loc="lower left",
title="Multiplexing",
)
legend.get_title().set_weight("bold")
plt.tight_layout()
return fig
except Exception as e:
logger.error(f"Error creating read length vs performance plots: {str(e)}")
raise
performance_readlength_fig = plot_read_length_vs_performance(
snv_metrics_df=snv_rtg_metrics_dfs["ont"],
indel_metrics_df=indel_rtg_metrics_dfs["ont"],
metrics_df=nanoplot_qc_metrics_df,
)
In [40]:
@dataclass
class SVMetrics:
"""Data class for structural variant metrics."""
type: str
length: Optional[int]
chrom: str
start: int
end: int
allele_idx: int
@dataclass
class SVAnalysisConfig:
"""Configuration for SV analysis."""
base_path: Path
technologies: Tuple[str, ...] = ("ont", "illumina")
def parse_int_or_first(value: Any) -> int:
"""
Parse integer value from various input types.
Args:
value: Input value that could be int, float, str, or tuple
Returns:
Parsed integer value
Raises:
ValueError: If value cannot be parsed as integer
"""
if isinstance(value, (int, float)):
return int(value)
elif isinstance(value, str):
return int(value.split("/")[0])
elif isinstance(value, tuple):
return int(value[0])
else:
raise ValueError(f"Unexpected value type: {type(value)}")
def handle_str(
record: Any, alt: str, chrom: str, start: int, end: int, alt_idx: int
) -> Optional[Dict[str, Any]]:
"""
Handle Short Tandem Repeat (STR) variants.
Args:
record: VCF record
alt: Alternative allele
chrom: Chromosome
start: Start position
end: End position
alt_idx: Alternative allele index
Returns:
Dictionary containing STR information or None if invalid
"""
try:
repcn = record.samples[0].get("REPCN")
if repcn is not None:
if isinstance(repcn, tuple):
repeat_count = parse_int_or_first(repcn[alt_idx])
else:
repeat_count = parse_int_or_first(repcn)
elif alt.startswith("<STR"):
str_alleles = record.alts
current_alt = str_alleles[alt_idx]
repeat_count = int(current_alt[4:-1])
else:
return None
ru = record.info.get("RU", "")
sv_len = repeat_count * len(ru)
return {
"type": "STR",
"length": sv_len,
"chrom": chrom,
"start": start,
"end": end,
}
except Exception as e:
logger.error(f"Error handling STR variant: {str(e)}")
return None
def handle_symbolic_allele(
record: Any, alt: str, chrom: str, start: int, end: int, sv_type: str, alt_idx: int
) -> Dict[str, Any]:
"""
Handle symbolic allele variants.
Args:
record: VCF record
alt: Alternative allele
chrom: Chromosome
start: Start position
end: End position
sv_type: Structural variant type
alt_idx: Alternative allele index
Returns:
Dictionary containing symbolic allele information
"""
sv_len = None
if sv_type == "INV" and "SVINSLEN" in record.info:
sv_len = record.info.get("SVINSLEN")
if isinstance(sv_len, tuple):
sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]
if sv_len is None:
sv_len = record.info.get("SVLEN")
if isinstance(sv_len, tuple):
sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]
if sv_len is None and sv_type == "INS":
left_seq = record.info.get("LEFT_SVINSSEQ", "")
right_seq = record.info.get("RIGHT_SVINSSEQ", "")
sv_len = len(left_seq) + len(right_seq)
if sv_len is None:
sv_len = end - start
return {
"type": sv_type,
"length": abs(sv_len) if sv_len is not None else None,
"chrom": chrom,
"start": start,
"end": end,
}
def handle_standard_sv(
record: Any, alt: str, chrom: str, start: int, end: int, sv_type: str, alt_idx: int
) -> Dict[str, Any]:
"""
Handle standard structural variants.
Args:
record: VCF record
alt: Alternative allele
chrom: Chromosome
start: Start position
end: End position
sv_type: Structural variant type
alt_idx: Alternative allele index
Returns:
Dictionary containing standard SV information
"""
if "SVLEN" in record.info:
sv_len = record.info["SVLEN"]
if isinstance(sv_len, tuple):
sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]
elif sv_type == "INS":
left_seq = record.info.get("LEFT_SVINSSEQ", "")
right_seq = record.info.get("RIGHT_SVINSSEQ", "")
sv_len = len(left_seq) + len(right_seq)
else:
ref_len = len(record.ref)
alt_len = len(alt)
sv_len = alt_len - ref_len if sv_type == "INS" else ref_len - alt_len
return {
"type": sv_type,
"length": abs(sv_len),
"chrom": chrom,
"start": start,
"end": end,
}
def extract_sv_info(record: Any, alt: str, alt_idx: int) -> Optional[Dict[str, Any]]:
"""
Extract structural variant information from VCF record.
Args:
record: VCF record
alt: Alternative allele
alt_idx: Alternative allele index
Returns:
Dictionary containing SV information or None if invalid
"""
chrom = record.chrom
start = record.pos
end = record.info.get("END", start)
sv_type = record.info.get("SVTYPE", "Unknown")
if sv_type == "STR" or (isinstance(alt, str) and alt.startswith("<STR")):
return handle_str(record, alt, chrom, start, end, alt_idx)
elif isinstance(alt, str) and alt.startswith("<") and alt.endswith(">"):
return handle_symbolic_allele(record, alt, chrom, start, end, sv_type, alt_idx)
else:
return handle_standard_sv(record, alt, chrom, start, end, sv_type, alt_idx)
def read_sv_vcf_file(file_path: Path) -> pl.DataFrame:
"""
Read and parse structural variant VCF file.
Args:
file_path: Path to the VCF file
Returns:
Polars DataFrame containing parsed SV information
Raises:
FileNotFoundError: If VCF file does not exist
ValueError: If VCF file is malformed
"""
try:
svs: List[Dict] = []
with pysam.VariantFile(file_path) as vcf:
for record in vcf:
for alt_idx, alt in enumerate(record.alts):
sv_info = extract_sv_info(record, alt, alt_idx)
if sv_info:
sv_info["allele_idx"] = alt_idx
svs.append(sv_info)
return pl.DataFrame(svs)
except FileNotFoundError:
logger.error(f"VCF file not found: {file_path}")
raise
except Exception as e:
logger.error(f"Error reading VCF file {file_path}: {str(e)}")
raise ValueError(f"Error parsing VCF file: {str(e)}")
def analyze_sv_calls(
sample_id: str,
ont_id: str,
illumina_id: str,
base_path: Path = Path("/scratch/prj/ppn_als_longread/ont-benchmark"),
) -> pl.DataFrame:
"""
Analyze structural variant calls from ONT and Illumina data.
Args:
sample_id: Sample identifier
ont_id: ONT sample identifier
illumina_id: Illumina sample identifier
base_path: Base path for input files
Returns:
Polars DataFrame containing combined SV analysis results
Raises:
FileNotFoundError: If required input files are missing
"""
try:
sv_path = base_path / "output/sv/survivor"
ont_file = sv_path / ont_id / f"{ont_id}.ont.sv_str.filtered.vcf"
illumina_file = sv_path / ont_id / f"{illumina_id}.illumina.sv.filtered.vcf"
merged_file = sv_path / ont_id / f"{ont_id}_{illumina_id}_merged.vcf"
ont_svs = read_sv_vcf_file(ont_file)
illumina_svs = read_sv_vcf_file(illumina_file)
merged_svs = read_sv_vcf_file(merged_file)
# Create unique identifiers
ont_svs = ont_svs.with_columns(
pl.concat_str(
pl.col("chrom"),
pl.col("start"),
pl.col("end"),
pl.col("type"),
separator="_",
).alias("sv_id")
)
illumina_svs = illumina_svs.with_columns(
pl.concat_str(
pl.col("chrom"),
pl.col("start"),
pl.col("end"),
pl.col("type"),
separator="_",
).alias("sv_id")
)
merged_svs = merged_svs.with_columns(
pl.concat_str(
pl.col("chrom"),
pl.col("start"),
pl.col("end"),
pl.col("type"),
separator="_",
).alias("sv_id")
)
# Combine and mark sources
all_svs = pl.concat([ont_svs, illumina_svs]).unique(subset="sv_id")
all_svs = all_svs.with_columns(
[
pl.col("sv_id").is_in(ont_svs["sv_id"]).alias("ONT"),
pl.col("sv_id").is_in(illumina_svs["sv_id"]).alias("Illumina"),
pl.col("sv_id").is_in(merged_svs["sv_id"]).alias("Merged"),
pl.lit(sample_id).alias("sample_id"),
]
)
return all_svs.drop("sv_id")
except Exception as e:
logger.error(f"Error analyzing SV calls for sample {sample_id}: {str(e)}")
raise
sv_data_list = []
for row in sample_ids.iter_rows(named=True):
try:
sample_data = analyze_sv_calls(
sample_id=row["ont_id"], ont_id=row["ont_id"], illumina_id=row["lp_id"]
)
sv_data_list.append(sample_data)
except Exception as e:
logger.error(f"Failed to process sample {row['ont_id']}: {str(e)}")
continue
sv_data_df = pl.concat(sv_data_list)
logger.info(f"Total SV calls processed: {sv_data_df.height}")
sv_data_df
__main__ - INFO - Total SV calls processed: 508922
Out[40]:
shape: (508_922, 10)
| type | length | chrom | start | end | allele_idx | ONT | Illumina | Merged | sample_id |
|---|---|---|---|---|---|---|---|---|---|
| str | i64 | str | i64 | i64 | i64 | bool | bool | bool | str |
| "INS" | 80 | "chr7" | 50802725 | 50802725 | 0 | true | false | false | "A046_12" |
| "DEL" | 61 | "chr7" | 48848787 | 48848787 | 0 | false | true | true | "A046_12" |
| "DEL" | 54 | "chr1" | 101657878 | 101657878 | 0 | true | false | false | "A046_12" |
| "INS" | 31 | "chr11" | 55765038 | 55765038 | 0 | true | false | false | "A046_12" |
| "INS" | 132 | "chr8" | 535731 | 535731 | 0 | true | false | false | "A046_12" |
| … | … | … | … | … | … | … | … | … | … |
| "DEL" | 673 | "chr16" | 32507512 | 32507512 | 0 | false | true | false | "A162_09" |
| "DEL" | 205 | "chr16" | 88122050 | 88122050 | 0 | false | true | true | "A162_09" |
| "INS" | 512 | "chr17" | 26701545 | 26701545 | 0 | true | false | false | "A162_09" |
| "INS" | 54 | "chr3" | 76698102 | 76698102 | 0 | true | false | false | "A162_09" |
| "DEL" | 165 | "chrX" | 364576 | 364576 | 0 | true | false | false | "A162_09" |
In [41]:
def compare_sv_counts(sv_data_df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate SV counts across different technologies for each sample.
Args:
sv_data_df: Input DataFrame containing SV data with columns:
sample_id, ONT, Illumina, and Merged
Returns:
DataFrame with anonymized sample IDs and counts for each technology
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"sample_id", "ONT", "Illumina", "Merged"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
counts = (
sv_data_df.group_by("sample_id")
.agg(
[
pl.col("ONT").sum().alias("long-read"),
pl.col("Illumina").sum().alias("short-read"),
pl.col("Merged").sum().alias("consensus"),
]
)
.sort("sample_id")
)
# Add anonymized sample IDs
counts = counts.with_columns(
[
pl.Series(
name="anonymised_sample",
values=[f"Sample {i+1}" for i in range(counts.height)],
)
]
)
return counts.select(
["anonymised_sample", "long-read", "short-read", "consensus"]
)
except Exception as e:
logger.error(f"Error comparing SV counts: {str(e)}")
raise
def plot_sv_counts(
sv_counts_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot structural variant counts across different technologies for each sample.
Args:
sv_counts_df: DataFrame containing SV count data
figsize: Figure size as (width, height)
dpi: Figure resolution
gs: Optional GridSpec for plotting within a larger figure
Returns:
Figure object if created independently (gs=None)
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"anonymised_sample", "long-read", "short-read", "consensus"}
if not all(col in sv_counts_df.columns for col in required_cols):
missing = required_cols - set(sv_counts_df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
plot_data = sv_counts_df.unpivot(
index=["anonymised_sample"],
on=["long-read", "short-read", "consensus"],
variable_name="Technology",
value_name="Count",
)
sns.barplot(
data=plot_data, x="anonymised_sample", y="Count", hue="Technology", ax=ax
)
ax.set_title("SV Call Counts by Sample and Platform")
ax.set_xlabel("Sample")
ax.set_ylabel("Number of SV Calls")
legend = ax.legend(title="Technology")
legend.get_title().set_weight("bold")
locs, labels = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.1 for loc in locs])
for tick in ax.get_xticklabels():
tick.set_rotation(45)
tick.set_ha("right")
if gs is None:
plt.tight_layout()
return fig
else:
return None
except Exception as e:
logger.error(f"Error plotting SV counts: {str(e)}")
raise
sv_counts_df = compare_sv_counts(sv_data_df)
sv_counts_plot = plot_sv_counts(sv_counts_df)
In [42]:
def calculate_consensus_percentages(sv_counts_df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate consensus percentages between ONT/Illumina and consensus calls.
Args:
sv_counts_df: DataFrame containing SV count data with columns:
anonymised_sample, long-read, short-read, and consensus
Returns:
DataFrame with added consensus percentage columns and printed statistics
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"long-read", "short-read", "consensus"}
if not all(col in sv_counts_df.columns for col in required_cols):
missing = required_cols - set(sv_counts_df.columns)
raise ValueError(f"Missing required columns: {missing}")
result_df = sv_counts_df.with_columns(
[
(pl.col("consensus") / pl.col("long-read") * 100)
.fill_null(0)
.alias("ONT_Consensus_Percent"),
(pl.col("consensus") / pl.col("short-read") * 100)
.fill_null(0)
.alias("Illumina_Consensus_Percent"),
]
)
# Calculate statistics
ont_stats = result_df.select(
[
pl.col("ONT_Consensus_Percent").mean().alias("mean"),
pl.col("ONT_Consensus_Percent").std().alias("std"),
]
).row(0)
illumina_stats = result_df.select(
[
pl.col("Illumina_Consensus_Percent").mean().alias("mean"),
pl.col("Illumina_Consensus_Percent").std().alias("std"),
]
).row(0)
logger.info(
f"Average consensus percentage for ONT: "
f"{ont_stats[0]:.2f}% ± {ont_stats[1]:.2f}% (mean ± SD)"
)
logger.info(
f"Average consensus percentage for Illumina: "
f"{illumina_stats[0]:.2f}% ± {illumina_stats[1]:.2f}% (mean ± SD)"
)
return result_df
except Exception as e:
logger.error(f"Error calculating consensus percentages: {str(e)}")
raise
sv_consensus_df = calculate_consensus_percentages(sv_counts_df)
__main__ - INFO - Average consensus percentage for ONT: 20.61% ± 1.60% (mean ± SD)
__main__ - INFO - Average consensus percentage for Illumina: 57.87% ± 10.11% (mean ± SD)
In [43]:
def calculate_average_difference(sv_counts_df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate average ratio difference between ONT and Illumina SV counts.
Args:
sv_counts_df: DataFrame containing SV count data with long-read and short-read columns
Returns:
DataFrame with added ONT/Illumina ratio column
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"long-read", "short-read"}
if not all(col in sv_counts_df.columns for col in required_cols):
missing = required_cols - set(sv_counts_df.columns)
raise ValueError(f"Missing required columns: {missing}")
result_df = sv_counts_df.with_columns(
(pl.col("long-read") / pl.col("short-read"))
.fill_null(0)
.alias("ONT_Illumina_Ratio")
)
# Calculate statistics
stats = result_df.select(
[
pl.col("ONT_Illumina_Ratio").mean().alias("mean"),
pl.col("ONT_Illumina_Ratio").std().alias("std"),
]
).row(0)
logger.info(
f"Average ratio of SV counts between ONT and Illumina: "
f"{stats[0]:.2f} ± {stats[1]:.2f} (mean ± SD)"
)
return result_df
except Exception as e:
logger.error(f"Error calculating average difference: {str(e)}")
raise
average_diff_df = calculate_average_difference(sv_counts_df)
__main__ - INFO - Average ratio of SV counts between ONT and Illumina: 2.86 ± 0.70 (mean ± SD)
2. SV Size Distribution¶
In [44]:
def plot_sv_size_distributions(
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot length distributions of structural variants for ONT and Illumina data.
Args:
sv_data_df: DataFrame containing SV data with columns 'length', 'ONT', and 'Illumina'
figsize: Figure size as (width, height)
dpi: Figure resolution
gs: Optional GridSpec for plotting within a larger figure
Returns:
Figure object if created independently (gs=None)
Raises:
ValueError: If required columns are missing or data is invalid
"""
try:
required_cols = {"length", "ONT", "Illumina"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs[0, 0])
# Extract and clean length data for each technology
ont_lengths = (
sv_data_df.filter(pl.col("ONT"))
.select("length")
.drop_nulls()
.filter(pl.col("length").is_finite())
.get_column("length")
.to_list()
)
illumina_lengths = (
sv_data_df.filter(pl.col("Illumina"))
.select("length")
.drop_nulls()
.filter(pl.col("length").is_finite())
.get_column("length")
.to_list()
)
# Plot distributions
sns.histplot(
ont_lengths,
log_scale=True,
bins=50,
stat="density",
kde=True,
alpha=0.7,
label="long-read",
ax=ax,
)
sns.histplot(
illumina_lengths,
log_scale=True,
bins=50,
stat="density",
kde=True,
alpha=0.7,
label="short-read",
ax=ax,
)
ax.set_title("SV Size Distribution")
ax.set_xlabel("SV Size (bp)")
ax.set_ylabel("Density")
legend = ax.legend(title="Technology")
legend.get_title().set_weight("bold")
if gs is None:
plt.tight_layout()
return fig
else:
return None
except Exception as e:
logger.error(f"Error plotting SV length distributions: {str(e)}")
raise
sv_size_plot = plot_sv_size_distributions(sv_data_df)
In [45]:
@dataclass
class SVSizeStats:
"""Data class for structural variant size statistics."""
maximum: float
minimum: float
mean: float
std: float
median: float
def calculate_sv_size_stats(lengths: List[float]) -> SVSizeStats:
"""
Calculate statistical measures for structural variant sizes.
Args:
lengths: List of SV lengths to analyze
Returns:
SVSizeStats object containing calculated statistics
Raises:
ValueError: If input list is empty
"""
try:
if not lengths:
raise ValueError("Empty length list provided")
return SVSizeStats(
maximum=float(np.max(lengths)),
minimum=float(np.min(lengths)),
mean=float(np.mean(lengths)),
std=float(np.std(lengths)),
median=float(np.median(lengths)),
)
except Exception as e:
logger.error(f"Error calculating SV size statistics: {str(e)}")
raise
def format_number(num: Union[int, float]) -> str:
"""
Format numbers with appropriate separators and decimal places.
Args:
num: Number to format
Returns:
Formatted string representation of the number
"""
try:
if isinstance(num, (int, np.integer)):
return f"{num:,d}"
elif isinstance(num, float):
return f"{num:,.2f}"
return str(num)
except Exception as e:
logger.error(f"Error formatting number {num}: {str(e)}")
return str(num)
def analyze_sv_size_distributions(sv_data_df: pl.DataFrame) -> Dict[str, SVSizeStats]:
"""
Analyze size distributions of structural variants across technologies.
Args:
sv_data_df: DataFrame containing SV data with length and technology columns
Returns:
Dictionary mapping technology names to their SVSizeStats
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"length", "ONT", "Illumina"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Extract lengths for each technology
ont_lengths = (
sv_data_df.filter(pl.col("ONT"))
.select("length")
.filter(pl.col("length").is_finite())
.to_series()
.to_list()
)
illumina_lengths = (
sv_data_df.filter(pl.col("Illumina"))
.select("length")
.filter(pl.col("length").is_finite())
.to_series()
.to_list()
)
# Calculate statistics
stats = {
"ONT": calculate_sv_size_stats(ont_lengths),
"Illumina": calculate_sv_size_stats(illumina_lengths),
}
# Print results
for tech, tech_stats in stats.items():
print(f"\n{tech} SV Length Statistics:")
print("=" * 40)
for stat_name, value in tech_stats.__dict__.items():
print(f" {stat_name.capitalize():6s}: {format_number(value)}")
return stats
except Exception as e:
logger.error(f"Error analyzing SV size distributions: {str(e)}")
raise
sv_size_dist = analyze_sv_size_distributions(sv_data_df)
ONT SV Length Statistics: ======================================== Maximum: 129,371,498.00 Minimum: 12.00 Mean : 3,527.59 Std : 482,253.95 Median: 80.00 Illumina SV Length Statistics: ======================================== Maximum: 6,064.00 Minimum: 2.00 Mean : 165.62 Std : 164.47 Median: 96.00
3. SV Types¶
In [46]:
@dataclass
class SVTypeStats:
"""Data class for structural variant type statistics."""
mean: float
median: float
std_dev: float
def calculate_sv_type_stats(counts: List[int]) -> SVTypeStats:
"""
Calculate statistics for SV type counts.
Args:
counts: List of counts for a specific SV type
Returns:
SVTypeStats object containing calculated statistics
Raises:
ValueError: If input list is empty
"""
try:
if not counts:
raise ValueError("Empty counts list provided")
return SVTypeStats(
mean=float(np.mean(counts)),
median=float(np.median(counts)),
std_dev=float(np.std(counts)),
)
except Exception as e:
print(f"Error calculating SV type statistics: {str(e)}")
raise
def analyze_sv_types(sv_data_df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""
Analyze distribution of structural variant types across samples and technologies.
Args:
sv_data_df: DataFrame containing SV data with type, sample_id, and technology columns
Returns:
Tuple containing:
- DataFrame with SV type counts per sample and platform
- DataFrame with statistical summary of SV types across platforms
Raises:
ValueError: If required columns are missing
"""
try:
required_cols = {"type", "sample_id", "ONT", "Illumina"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Initialize data structures
type_counts: Dict[Tuple[str, str], Dict[str, int]] = defaultdict(dict)
type_stats: Dict[str, Dict[str, List[int]]] = defaultdict(
lambda: defaultdict(list)
)
# Process each sample and platform
for sample_id in sv_data_df.get_column("sample_id").unique():
sample_data = sv_data_df.filter(pl.col("sample_id") == sample_id)
for platform in ["ONT", "Illumina"]:
platform_data = sample_data.filter(pl.col(platform))
# Calculate type counts
type_counts_dict = (
platform_data.get_column("type")
.value_counts()
.sort("count", descending=True)
.to_dict(as_series=False)
)
# Store counts
for sv_type, count in zip(
type_counts_dict["type"], type_counts_dict["count"]
):
type_counts[(sample_id, platform.lower())][sv_type] = count
type_stats[platform.lower()][sv_type].append(count)
# Create counts DataFrame
counts_data = [
{"sample_id": sample_id, "platform": platform, **counts}
for (sample_id, platform), counts in type_counts.items()
]
df_sv_type_counts = pl.DataFrame(counts_data)
# Calculate statistics
stats_data = [
{
"platform": platform,
"sv_type": sv_type,
**calculate_sv_type_stats(counts).__dict__,
}
for platform, sv_types in type_stats.items()
for sv_type, counts in sv_types.items()
]
df_sv_type_stats = pl.DataFrame(stats_data)
# Print summary
print("\nSV Type Counts Summary:")
print("=" * 40)
with pl.Config(tbl_rows=len(df_sv_type_counts)):
display(df_sv_type_counts)
print("\nSV Type Statistics Summary:")
print("=" * 40)
display(df_sv_type_stats)
return df_sv_type_counts, df_sv_type_stats
except Exception as e:
print(f"Error analyzing SV types: {str(e)}")
raise
sv_type_counts_df, sv_type_stats_df = analyze_sv_types(sv_data_df)
SV Type Counts Summary: ========================================
shape: (28, 8)
| sample_id | platform | INS | DEL | STR | BND | INV | DUP |
|---|---|---|---|---|---|---|---|
| str | str | i64 | i64 | i64 | i64 | i64 | i64 |
| "A097_92" | "ont" | 12668 | 10524 | 23 | 20 | 13 | 8 |
| "A097_92" | "illumina" | 3260 | 5578 | null | 765 | 1 | null |
| "A157_02" | "ont" | 16682 | 13597 | 33 | 75 | 21 | 17 |
| "A157_02" | "illumina" | 3258 | 5560 | null | 788 | 1 | null |
| "A079_07" | "ont" | 11587 | 9365 | 18 | 15 | 13 | 5 |
| "A079_07" | "illumina" | 3353 | 5539 | null | 725 | 1 | null |
| "A154_06" | "ont" | 17361 | 14081 | 34 | 66 | 26 | 9 |
| "A154_06" | "illumina" | 3059 | 5170 | null | 666 | null | null |
| "A153_01" | "ont" | 16330 | 13465 | 26 | 63 | 21 | 10 |
| "A153_01" | "illumina" | 3222 | 5463 | null | 801 | 1 | null |
| "A081_91" | "ont" | 8978 | 7301 | 11 | 10 | 8 | 5 |
| "A081_91" | "illumina" | 3472 | 5644 | null | 811 | 1 | null |
| "A153_06" | "ont" | 18705 | 15166 | 33 | 98 | 32 | 18 |
| "A153_06" | "illumina" | 3088 | 5429 | null | 702 | 1 | null |
| "A154_04" | "ont" | 16385 | 13348 | 33 | 58 | 27 | 8 |
| "A154_04" | "illumina" | 3233 | 5670 | null | 734 | 1 | null |
| "A085_00" | "ont" | 9662 | 7898 | 12 | 17 | 8 | 5 |
| "A085_00" | "illumina" | 3050 | 5313 | null | 732 | 1 | null |
| "A162_09" | "ont" | 19332 | 15542 | 36 | 104 | 40 | 26 |
| "A162_09" | "illumina" | 3274 | 5493 | null | 742 | 1 | null |
| "A048_09" | "ont" | 13394 | 11160 | 21 | 28 | 21 | 6 |
| "A048_09" | "illumina" | 3435 | 5715 | null | 769 | null | null |
| "A046_12" | "ont" | 11529 | 9625 | 18 | 16 | 19 | 8 |
| "A046_12" | "illumina" | 3262 | 5392 | null | 675 | null | null |
| "A149_01" | "ont" | 15259 | 12468 | 28 | 50 | 22 | 12 |
| "A149_01" | "illumina" | 3083 | 5425 | null | 694 | 1 | null |
| "A160_96" | "ont" | 19113 | 15504 | 33 | 124 | 36 | 17 |
| "A160_96" | "illumina" | 3172 | 5491 | null | 773 | 1 | null |
SV Type Statistics Summary: ========================================
shape: (10, 5)
| platform | sv_type | mean | median | std_dev |
|---|---|---|---|---|
| str | str | f64 | f64 | f64 |
| "ont" | "INS" | 14784.642857 | 15794.5 | 3352.940769 |
| "ont" | "DEL" | 12074.571429 | 12908.0 | 2675.388322 |
| "ont" | "STR" | 25.642857 | 27.0 | 8.21677 |
| "ont" | "BND" | 53.142857 | 54.0 | 35.981855 |
| "ont" | "INV" | 21.928571 | 21.0 | 9.361504 |
| "ont" | "DUP" | 11.0 | 8.5 | 6.047432 |
| "illumina" | "DEL" | 5491.571429 | 5492.0 | 139.839047 |
| "illumina" | "INS" | 3230.071429 | 3245.5 | 127.395035 |
| "illumina" | "BND" | 741.214286 | 738.0 | 44.070039 |
| "illumina" | "INV" | 1.0 | 1.0 | 0.0 |
In [47]:
def plot_sv_types(
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot structural variant types by sample for both long-read and short-read data.
Args:
sv_data_df: DataFrame containing SV data with columns 'sample_id', 'type', 'ONT', and 'Illumina'
figsize: Figure size as (width, height)
dpi: Figure resolution
gs: Optional GridSpec for plotting within a larger figure
Returns:
Figure object if created independently (gs=None)
Raises:
ValueError: If required columns are missing or data is invalid
"""
try:
required_cols = {"sample_id", "type", "ONT", "Illumina"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Create figure and subplots
if gs is None:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs_local = gridspec.GridSpec(2, 1, figure=fig)
else:
fig = plt.gcf()
gs_local = gs
# Platform configuration
platform_config = {
"ONT": {"title": "Long-read", "position": 0},
"Illumina": {"title": "Short-read", "position": 1},
}
# Map sample IDs to anonymised sample IDs using nanoplot_qc_metrics_df
sample_map = dict(
zip(
nanoplot_qc_metrics_df.get_column("sample").to_list(),
nanoplot_qc_metrics_df.get_column("anonymised_sample").to_list(),
)
)
# Compute globally sorted variant types (alphabetically)
all_variant_types = sorted(sv_data_df.get_column("type").unique().to_list())
# Loop over platforms
for platform, config in platform_config.items():
rows, cols = gs_local.get_geometry()
if rows == 1:
ax = fig.add_subplot(gs_local[0, config["position"]])
else:
ax = fig.add_subplot(gs_local[config["position"], 0])
# Filter and prepare data for platform using Polars
platform_data = (
sv_data_df.filter(pl.col(platform))
.with_columns(
pl.col("sample_id").replace(sample_map).alias("anonymised_sample")
)
.group_by(["anonymised_sample", "type"])
.agg(pl.len().alias("count"))
.pivot(values="count", index="anonymised_sample", on="type")
.fill_null(0)
.with_columns(
pl.col("anonymised_sample")
.str.extract(r"(\d+)$")
.cast(pl.Int64)
.alias("sample_number")
)
.sort("sample_number")
.drop("sample_number")
)
# Ensure all variant type columns exist and are ordered alphabetically
for vt in all_variant_types:
if vt not in platform_data.columns:
platform_data = platform_data.with_columns(pl.lit(0).alias(vt))
platform_data = platform_data.select(
["anonymised_sample"] + all_variant_types
)
# Plot stacked bar chart directly using Matplotlib
x = platform_data.get_column("anonymised_sample").to_list()
bottom = np.zeros(len(x))
for vt in all_variant_types:
counts = platform_data.get_column(vt).to_numpy()
ax.bar(x, counts, bottom=bottom, label=vt)
bottom += counts
ax.set_title(f"SV Types by Sample - {config['title']}")
ax.set_xlabel("Sample")
ax.set_ylabel("Number of SVs")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
locs, labels = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.2 for loc in locs])
if gs is None and config["position"] == 0:
legend = ax.legend(title="SV Type", bbox_to_anchor=(1, 1))
legend.get_title().set_weight("bold")
elif gs is not None and config["position"] == 1:
legend = ax.legend(title="SV Type", bbox_to_anchor=(1.05, 1.05))
legend.get_title().set_weight("bold")
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting SV types: {str(e)}")
raise
sv_types_plot = plot_sv_types(sv_data_df)
Combined Plots¶
In [48]:
def create_combined_sv_analysis_plot(
sv_data_df: pl.DataFrame,
sv_counts_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing SV analysis plots in 2x2 grid.
Args:
sv_data_df: DataFrame containing structural variant data
sv_counts_df: DataFrame containing SV count data
figsize: Figure size as (width, height)
dpi: Figure resolution
Returns:
Combined figure object
"""
try:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(2, 2)
# Plot SV Calls per Platform (A)
plot_sv_counts(
sv_counts_df,
gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 0]),
)
# Plot SV Size Distribution (B)
plot_sv_size_distributions(
sv_data_df, gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 1])
)
# Plot Long Read and Short Read SV Types (C & D)
plot_sv_types(
sv_data_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[1, :])
)
# Add panel labels
for i, label in enumerate(["A", "B", "C", "D"]):
ax = fig.axes[i]
ax.text(
-0.1,
1.05,
label,
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined SV plot: {str(e)}")
raise
combined_sv_analysis_plot = create_combined_sv_analysis_plot(sv_data_df, sv_counts_df)
In [49]:
@dataclass
class SVTypeConfig:
"""Configuration for SV type display names and plotting settings."""
name: str
use_log_scale: bool = True
def plot_sv_size_distribution_by_type(
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing size distribution of structural variants by type
for both long-read and short-read data in a 3x2 grid.
Args:
sv_data_df: DataFrame containing SV data with columns 'type', 'length', 'ONT', and 'Illumina'
figsize: Figure size as (width, height)
dpi: Figure resolution
Returns:
Combined figure object
Raises:
ValueError: If required columns are missing or data is invalid
"""
try:
required_cols = {"type", "length", "ONT", "Illumina"}
if not all(col in sv_data_df.columns for col in required_cols):
missing = required_cols - set(sv_data_df.columns)
raise ValueError(f"Missing required columns: {missing}")
# Configuration
sv_types_config = {
"INS": SVTypeConfig("Insertion"),
"DEL": SVTypeConfig("Deletion"),
"DUP": SVTypeConfig("Duplication"),
"INV": SVTypeConfig("Inversion"),
"BND": SVTypeConfig("Breakend", use_log_scale=False),
"STR": SVTypeConfig("Short Tandem Repeat", use_log_scale=False),
}
platform_config = {
"ONT": ("long-read", 0.7),
"Illumina": ("short-read", 0.7),
}
# Create main figure and GridSpec
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(3, 2)
# Create subplots
for idx, (sv_type, config) in enumerate(sv_types_config.items()):
ax = fig.add_subplot(gs[idx // 2, idx % 2])
for platform, (label, alpha) in platform_config.items():
sv_data = sv_data_df.filter(
(pl.col("type") == sv_type) & pl.col(platform)
)
if not sv_data.is_empty():
sns.histplot(
data=sv_data,
x="length",
kde=True,
log_scale=config.use_log_scale,
stat="density",
ax=ax,
alpha=alpha,
label=label,
)
ax.set_title(f"{config.name} Size Distribution")
ax.set_xlabel(f"SV Size (bp)")
ax.set_ylabel("Density")
# Add panel labels
ax.text(
-0.1,
1.05,
chr(65 + idx), # A, B, C, D, E, F
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
if idx == 0:
legend = ax.legend(title="Technology")
legend.get_title().set_weight("bold")
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error plotting SV size distribution: {str(e)}")
raise
sv_size_dist_plot = plot_sv_size_distribution_by_type(sv_data_df)
4. SV Chromosomal Distribution¶
In [50]:
@dataclass
class ChromosomeData:
"""Data class for chromosome lengths."""
lengths: Dict[str, int] = field(
default_factory=lambda: {
"chr1": 248956422,
"chr2": 242193529,
"chr3": 198295559,
"chr4": 190214555,
"chr5": 181538259,
"chr6": 170805979,
"chr7": 159345973,
"chr8": 145138636,
"chr9": 138394717,
"chr10": 133797422,
"chr11": 135086622,
"chr12": 133275309,
"chr13": 114364328,
"chr14": 107043718,
"chr15": 101991189,
"chr16": 90338345,
"chr17": 83257441,
"chr18": 80373285,
"chr19": 58617616,
"chr20": 64444167,
"chr21": 46709983,
"chr22": 50818468,
"chrX": 156040895,
"chrY": 57227415,
}
)
@property
def ordered_chroms(self) -> list[str]:
"""Return chromosomes in proper order."""
return [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
def normalize_by_chrom_length(
chrom_distribution_df: pl.DataFrame, chrom_data: ChromosomeData
) -> pl.DataFrame:
"""
Normalize SV counts by chromosome length.
Args:
chrom_distribution_df: DataFrame with chromosome distribution
chrom_data: ChromosomeData instance with chromosome lengths
Returns:
Normalized DataFrame
Raises:
ValueError: If chromosome data is missing
"""
try:
# Create a dictionary for chromosome lengths in millions of base pairs
chrom_lengths_mb = {k: v / 1e6 for k, v in chrom_data.lengths.items()}
# Convert to LazyFrame for more efficient operations
normalized = chrom_distribution_df.lazy()
# Create normalized columns using replace_strict instead of map_elements
normalized = normalized.with_columns(
[
(
pl.col("ont") / pl.col("chrom").replace_strict(chrom_lengths_mb)
).alias("ont"),
(
pl.col("illumina")
/ pl.col("chrom").replace_strict(chrom_lengths_mb)
).alias("illumina"),
]
)
return normalized.collect()
except Exception as e:
logger.error(f"Error normalizing chromosome distribution: {str(e)}")
raise
def analyze_chrom_distribution(sv_data_df: pl.DataFrame) -> pl.DataFrame:
"""
Analyze the distribution of structural variants across chromosomes for ONT and Illumina platforms.
Args:
sv_data_df: DataFrame containing SV data with ONT and Illumina boolean columns
and a 'chrom' column
Returns:
DataFrame with chromosome distribution counts for both platforms
Raises:
ValueError: If required columns are missing
"""
try:
# Define valid chromosomes
valid_chroms = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
# Create counts for ONT data
ont_counts = (
sv_data_df.filter(pl.col("ONT") == True)
.group_by("chrom")
.len()
.with_columns(pl.col("len").alias("ont"))
.drop("len")
)
# Create counts for Illumina data
illumina_counts = (
sv_data_df.filter(pl.col("Illumina") == True)
.group_by("chrom")
.len()
.with_columns(pl.col("len").alias("illumina"))
.drop("len")
)
# Join the counts and fill missing values with 0
result = (
pl.DataFrame({"chrom": valid_chroms})
.join(ont_counts, on="chrom", how="left")
.join(illumina_counts, on="chrom", how="left")
.with_columns([pl.col("ont").fill_null(0), pl.col("illumina").fill_null(0)])
)
return result
except Exception as e:
logger.error(f"Error analyzing chromosome distribution: {str(e)}")
raise
def plot_chrom_distribution(
chrom_distribution_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 6),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot chromosome distribution of structural variants.
Args:
chrom_distribution_df: DataFrame with chromosome distribution
figsize: Figure size as (width, height)
dpi: Figure resolution
gs: Optional GridSpec for plotting within a larger figure
Returns:
Figure object if created independently (gs=None)
Raises:
ValueError: If required data is missing
"""
try:
chrom_data = ChromosomeData()
normalized_df = normalize_by_chrom_length(chrom_distribution_df, chrom_data)
total_svs = normalized_df.select(
[pl.col("ont").sum(), pl.col("illumina").sum()]
)
normalized_df = normalized_df.with_columns(
[
(pl.col("ont") / total_svs.item(0, 0) * 100).alias("long-read"),
(pl.col("illumina") / total_svs.item(0, 1) * 100).alias("short-read"),
]
).drop(["ont", "illumina"])
plot_data = normalized_df.unpivot(
index=["chrom"],
on=["long-read", "short-read"],
variable_name="Platform",
value_name="value",
)
plot_data = (
plot_data.with_columns(pl.col("chrom").cast(pl.Categorical))
.with_columns(
pl.col("chrom")
.cast(pl.Categorical)
.map_elements(
lambda x: (
chrom_data.ordered_chroms.index(x)
if x in chrom_data.ordered_chroms
else -1
),
return_dtype=pl.Int64,
)
.alias("chrom_order")
)
.sort("chrom_order")
.drop("chrom_order")
)
if gs is None:
fig = plt.figure(figsize=figsize, dpi=dpi)
ax = fig.add_subplot(111)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs)
sns.barplot(
data=plot_data,
x="chrom",
y="value",
hue="Platform",
order=chrom_data.ordered_chroms,
ax=ax,
)
ax.set_title("Normalised Chromosomal Distribution of SVs")
ax.set_xlabel("Chromosome")
ax.set_ylabel("Proportion of SVs per Mb (%)")
ax.legend(title="Technology").get_title().set_weight("bold")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
locs, labels = ax.get_xticks(), ax.get_xticklabels()
ax.set_xticks([loc + 0.18 for loc in locs])
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting chromosome distribution: {str(e)}")
raise
chrom_distribution_df = analyze_chrom_distribution(sv_data_df)
sv_chrom_dist_plot = plot_chrom_distribution(chrom_distribution_df)
In [51]:
def calculate_sv_correlations(
chrom_distribution_df: pl.DataFrame, chromosome_data: ChromosomeData
) -> Tuple[pl.DataFrame, float, float]:
"""
Calculate Pearson correlations between chromosome length and mean SV counts.
Args:
chrom_distribution_df: Polars DataFrame with SV counts per chromosome.
chromosome_data: ChromosomeData instance containing chromosome lengths.
Returns:
Tuple containing:
- Polars DataFrame with correlation data.
- Pearson correlation coefficient for ONT.
- Pearson correlation coefficient for Illumina.
"""
# Use ordered chromosomes from ChromosomeData
ordered_chroms = chromosome_data.ordered_chroms
# Calculate mean counts for each chromosome with consistent ordering
ont_mean_count = (
chrom_distribution_df.group_by("chrom")
.agg(pl.col("ont").mean().alias("ont_mean"))
.filter(pl.col("chrom").is_in(ordered_chroms))
.sort("chrom", descending=False)
)
illumina_mean_count = (
chrom_distribution_df.group_by("chrom")
.agg(pl.col("illumina").mean().alias("illumina_mean"))
.filter(pl.col("chrom").is_in(ordered_chroms))
.sort("chrom", descending=False)
)
# Create ordered lists of chromosome lengths
ordered_lengths = [chromosome_data.lengths[chrom] for chrom in ordered_chroms]
# Create correlation DataFrame with consistent ordering
corr_data = pl.DataFrame(
{
"chrom": ordered_chroms,
"length": ordered_lengths,
"ont_count": ont_mean_count["ont_mean"].to_list(),
"illumina_count": illumina_mean_count["illumina_mean"].to_list(),
}
)
# Calculate correlations using ordered data
ont_corr, ont_p = stats.pearsonr(corr_data["length"], corr_data["ont_count"])
illumina_corr, illumina_p = stats.pearsonr(
corr_data["length"], corr_data["illumina_count"]
)
print(f"ONT correlation: {ont_corr:.2f} (p-value: {ont_p:.2e})")
print(f"Illumina correlation: {illumina_corr:.2f} (p-value: {illumina_p:.2e})")
return corr_data, ont_corr, illumina_corr
def plot_sv_correlations(
corr_data: pl.DataFrame,
ont_corr: float,
illumina_corr: float,
figsize: Tuple[int, int] = (12, 5),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot correlations between chromosome length and SV counts for both technologies.
Args:
corr_data: Polars DataFrame containing correlation data.
ont_corr: Pearson correlation coefficient for ONT data.
illumina_corr: Pearson correlation coefficient for Illumina data.
figsize: Figure size as (width, height).
dpi: Figure resolution.
gs: Optional GridSpec for plotting within a larger figure.
Returns:
Figure object if created independently (gs=None).
Raises:
ValueError: If required correlation data is missing.
"""
try:
if gs is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
else:
fig = plt.gcf()
ax = fig.add_subplot(gs)
_, ont_p = stats.pearsonr(corr_data["length"], corr_data["ont_count"])
_, illumina_p = stats.pearsonr(corr_data["length"], corr_data["illumina_count"])
# Plot long-read data
sns.regplot(
x="length",
y="ont_count",
data=corr_data,
ax=ax,
label=f"long-read (r={ont_corr:.2f}, p={ont_p:.2e})",
scatter_kws={"alpha": 0.8, "label": "long-read data points"},
line_kws={
"color": sns.color_palette()[0],
"label": "long-read regression line",
},
)
# Plot short-read data
sns.regplot(
x="length",
y="illumina_count",
data=corr_data,
ax=ax,
label=f"short-read (r={illumina_corr:.2f}, p={illumina_p:.2e})",
scatter_kws={"alpha": 0.8, "label": "short-read data points"},
line_kws={
"color": sns.color_palette()[1],
"label": "short-read regression line",
},
)
ax.set_title("Structural Variant Counts vs Chromosome Length")
ax.set_xlabel("Chromosome Length (bp)")
ax.set_ylabel("Number of Structural Variants")
# Create custom legend
handles, labels = ax.get_legend_handles_labels()
# Add line and CI band descriptions to legend
legend_elements = [
Line2D(
[0],
[0],
color=sns.color_palette()[0],
label=f"long-read (r={ont_corr:.2f}, p={ont_p:.2e})",
),
Line2D(
[0],
[0],
color=sns.color_palette()[1],
label=f"short-read (r={illumina_corr:.2f}, p={illumina_p:.2e})",
),
Patch(
facecolor=sns.color_palette()[0], alpha=0.2, label="long-read 95% CI"
),
Patch(
facecolor=sns.color_palette()[1], alpha=0.2, label="short-read 95% CI"
),
]
ax.legend(handles=legend_elements, title="Technology")
ax.get_legend().get_title().set_weight("bold")
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting SV correlations: {str(e)}")
raise
chromosome_data = ChromosomeData()
corr_data, ont_corr, illumina_corr = calculate_sv_correlations(
chrom_distribution_df, chromosome_data
)
sv_length_corr_plot = plot_sv_correlations(corr_data, ont_corr, illumina_corr)
ONT correlation: 0.11 (p-value: 6.16e-01) Illumina correlation: 0.05 (p-value: 8.24e-01)
Combined Plots¶
In [52]:
def create_combined_sv_chr_plot(
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 4), # Modified for side-by-side layout
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing chromosome distribution and correlation plots side by side.
Args:
sv_data_df: DataFrame containing SV data
figsize: Figure size as (width, height)
dpi: Figure resolution
Returns:
Combined figure object
"""
try:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(1, 2)
plot_chrom_distribution(chrom_distribution_df, gs=gs[0])
ax_left = plt.gcf().get_axes()[0]
ax_left.text(
-0.1,
1.05,
"A",
transform=ax_left.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
plot_sv_correlations(corr_data, ont_corr, illumina_corr, gs=gs[1])
ax_right = plt.gcf().get_axes()[1]
ax_right.text(
-0.1,
1.05,
"B",
transform=ax_right.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined SV analysis plot: {str(e)}")
raise
combined_sv_chr_plot = create_combined_sv_chr_plot(sv_data_df)
5. Impact of sequencing depth on structural variants¶
In [53]:
def _fit_and_plot_asymptotic_curve(
x: np.ndarray,
y: np.ndarray,
ax: plt.Axes,
color: str = "#1f77b4", # Default matplotlib blue
alpha: float = 0.2,
) -> None:
"""Fit and plot asymptotic curve with confidence intervals.
Args:
x: Input x values (depth)
y: Input y values (SV counts)
ax: Matplotlib axes object to plot on
color: Color for the curve and confidence interval. Defaults to matplotlib blue
alpha: Transparency for confidence interval. Defaults to 0.2
Raises:
RuntimeError: If curve fitting fails
"""
def asymptotic_func(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
"""Asymptotic function for curve fitting."""
return a - b * np.exp(-c * x)
try:
# Fit curve
popt, pcov = curve_fit(
asymptotic_func,
x,
y,
p0=[np.max(y), np.max(y) - np.min(y), 0.1],
bounds=([0, 0, 0], [np.inf, np.inf, 1]),
)
# Generate points for smooth curve
x_range = np.linspace(x.min(), x.max(), 100)
y_fit = asymptotic_func(x_range, *popt)
ax.plot(
x_range,
y_fit,
color=color,
linestyle="-",
linewidth=2,
label="Line of best fit",
)
# Calculate and plot confidence intervals
perr = np.sqrt(np.diag(pcov))
n = len(x)
dof = max(0, n - len(popt))
t = stats.t.ppf(0.975, dof)
y_err = np.sqrt(np.sum((y - asymptotic_func(x, *popt)) ** 2) / dof)
ci = (
t
* y_err
* np.sqrt(
1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
)
)
ax.fill_between(
x_range,
y_fit - ci,
y_fit + ci,
color=color,
alpha=alpha,
label="95% Confidence Interval",
)
except RuntimeError as e:
logger.warning(f"Failed to fit asymptotic curve: {str(e)}")
raise
def _plot_sv_depth_correlation(
data: pl.DataFrame, metric: str, label: str, ax: plt.Axes, color: Any
) -> None:
"""Plot correlation between structural variant counts and sequencing depth.
Args:
data: Polars DataFrame containing depth and SV count data
metric: Column name for SV counts (e.g., 'ONT', 'Illumina', 'Merged')
label: Label for the plot title and legend
ax: Matplotlib axes object to plot on
color: Color for the scatter plot and fitted curve
Returns:
None
Raises:
RuntimeError: If asymptotic curve fitting fails
"""
x = data.get_column("wg_mean_depth").to_numpy()
y = data.get_column(metric).to_numpy()
sns.scatterplot(x=x, y=y, ax=ax, color=color)
try:
_fit_and_plot_asymptotic_curve(x, y, ax, color=color)
except RuntimeError:
logger.warning(f"Could not fit asymptotic curve for {label}")
r_value, p_value = stats.pearsonr(x, y)
ax.set_title(f"{label}\nr = {r_value:.2f}, p = {p_value:.2e}")
ax.set_xlabel("Whole Genome Mean Depth")
ax.set_ylabel("Number of SV Calls")
def plot_depth_vs_sv_performance(
wg_depth_df: pl.DataFrame,
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 4),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""Plot relationship between sequencing depth and SV detection performance.
Args:
wg_depth_df: DataFrame containing whole genome depth statistics
sv_data_df: DataFrame containing structural variant data
figsize: Figure dimensions. Defaults to (14, 4)
dpi: Figure resolution. Defaults to 300
gs: GridSpec for subplot placement. If None, creates standalone figure
Returns:
Optional[plt.Figure]: If gs is None, returns the figure.
If gs is provided, returns None (plots are added as subfigures).
Raises:
ValueError: If required columns are missing from input DataFrames
"""
try:
depth_data = (
wg_depth_df.filter(pl.col("chrom") == "chr1")
.select(["sample", "mean"])
.rename({"mean": "wg_mean_depth"})
)
sv_counts = sv_data_df.group_by("sample_id").agg(
[pl.col("ONT").sum(), pl.col("Merged").sum()]
)
analysis_data = sv_counts.join(
depth_data, left_on="sample_id", right_on="sample"
)
if gs is None:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = gridspec.GridSpec(1, 2, figure=fig)
else:
fig = plt.gcf()
metrics = ["ONT", "Merged"]
labels = ["Long-read", "Consensus"]
palette = sns.color_palette()
for i, (metric, label) in enumerate(zip(metrics, labels)):
ax = plt.subplot(gs[0, i])
_plot_sv_depth_correlation(
analysis_data, metric, label, ax, color=palette[i]
)
if i == 0: # Only add legend to the first plot
ax.legend()
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting depth vs SV performance: {str(e)}")
raise
sv_performance_plot = plot_depth_vs_sv_performance(wg_depth_df, sv_data_df)
In [54]:
def prepare_sv_depth_data(
total_depth_df: pl.DataFrame,
sv_data_df: pl.DataFrame,
) -> Tuple[pl.DataFrame, List[str]]:
"""
Prepare and join SV and depth data for analysis.
Args:
total_depth_df: DataFrame containing total depth statistics
sv_data_df: DataFrame containing structural variant data
Returns:
Tuple containing:
- Joined DataFrame with SV and depth information
- List of unique SV types
Raises:
ValueError: If required columns are missing
"""
try:
analysis_data = sv_data_df.join(
total_depth_df.select(["sample", "mean_depth"]),
left_on="sample_id",
right_on="sample",
)
sv_types = sorted(analysis_data.get_column("type").unique().to_list())
return analysis_data, sv_types
except Exception as e:
logger.error(f"Error preparing SV depth data: {str(e)}")
raise
def _plot_sv_size_distribution(
platform_data: pl.DataFrame,
sv_types: List[str],
label: str,
ax: plt.Axes,
gs: gridspec.GridSpec,
) -> None:
"""
Helper function to plot SV size distribution for a specific platform.
Args:
platform_data: DataFrame filtered for specific platform
sv_types: List of unique SV types
label: Platform label (Long-read/Short-read)
ax: Matplotlib axes object
gs: GridSpec for subplot placement
"""
x = platform_data.get_column("mean_depth").to_numpy()
y = np.log10(platform_data.get_column("length").to_numpy())
r, p = stats.pearsonr(x, y)
# Add correlation stats below title
ax.set_title(f"{label} SV Size vs Depth\nr = {r:.3f}, p = {p:.2e}")
palette = dict(zip(sv_types, sns.color_palette(n_colors=len(sv_types))))
# Plot regression line with CI
reg_plot = sns.regplot(
data=platform_data,
x="mean_depth",
y="length",
scatter=False,
ax=ax,
color="grey",
line_kws={"linestyle": "-", "alpha": 0.8, "label": "Line of best fit"},
ci=95,
)
# Get the CI lines for legend
ci_lines = [line for line in ax.lines if line != reg_plot.lines[0]]
if ci_lines:
ci_lines[0].set_label("95% CI")
# Then plot the scatter points on top
sns.scatterplot(
data=platform_data,
x="mean_depth",
y="length",
hue="type",
hue_order=sv_types,
palette=palette,
ax=ax,
alpha=0.6,
)
ax.set_yscale("log")
ax.set_xlabel("Whole Genome Mean Depth")
ax.set_ylabel("SV Size (bp)")
def plot_sv_size_vs_depth(
analysis_data: pl.DataFrame,
sv_types: List[str],
figsize: Tuple[int, int] = (12, 4),
dpi: int = 300,
gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
"""
Plot relationship between SV sizes and sequencing depth.
Args:
analysis_data: Prepared DataFrame containing joined SV and depth data
sv_types: List of unique SV types
figsize: Figure dimensions. Defaults to (8, 4)
dpi: Figure resolution. Defaults to 300
gs: GridSpec for subplot placement. If None, creates standalone figure
Returns:
Optional[plt.Figure]: If gs is None, returns the figure.
If gs is provided, returns None (plots are added as subfigures).
"""
try:
if gs is None:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = gridspec.GridSpec(1, 1, figure=fig)
else:
fig = plt.gcf()
# Removed Illumina from platforms list
platforms = [("ONT", "Long-read")]
for idx, (platform, label) in enumerate(platforms):
ax = plt.subplot(gs[0, idx])
_plot_sv_size_distribution(
analysis_data.filter(pl.col(platform)),
sv_types,
label,
ax,
gs,
)
from matplotlib.lines import Line2D
scatter_handles, scatter_labels = ax.get_legend_handles_labels()
line_of_best_fit = Line2D(
[],
[],
color="grey",
linestyle="-",
alpha=0.8,
label="Line of best fit",
)
ci_patch = Patch(color="grey", alpha=0.2, label="95% CI")
handles = scatter_handles + [line_of_best_fit, ci_patch]
labels = scatter_labels + ["Line of best fit", "95% CI"]
legend = ax.legend(
handles,
labels,
title="SV Type",
loc="upper left",
bbox_to_anchor=(1.05, 1),
)
legend.get_title().set_weight("bold")
if gs is None:
plt.tight_layout()
return fig
return None
except Exception as e:
logger.error(f"Error plotting SV size vs depth: {str(e)}")
raise
sv_depth_data, sv_types = prepare_sv_depth_data(total_depth_df, sv_data_df)
size_depth_plot = plot_sv_size_vs_depth(
sv_depth_data,
sv_types,
)
Combined Plots¶
In [55]:
def create_combined_sv_depth_plot(
wg_depth_df: pl.DataFrame,
total_depth_df: pl.DataFrame,
sv_data_df: pl.DataFrame,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
) -> plt.Figure:
"""
Create a combined figure showing SV calls vs depth and SV size vs depth analyses.
Args:
wg_depth_df: DataFrame containing whole genome depth statistics
total_depth_df: DataFrame containing total depth statistics
sv_data_df: DataFrame containing structural variant data
figsize: Figure size as (width, height)
dpi: Figure resolution
Returns:
Combined figure object
Raises:
ValueError: If required columns are missing from input DataFrames
"""
try:
fig = plt.figure(figsize=figsize, dpi=dpi)
gs = fig.add_gridspec(3, 1, height_ratios=[1, 0.1, 1])
gs_top = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0])
gs_bottom = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2])
# Plot SV calls vs depth in top row
plot_depth_vs_sv_performance(wg_depth_df, sv_data_df, gs=gs_top)
# Plot SV size vs depth in bottom row
plot_sv_size_vs_depth(sv_depth_data, sv_types, gs=gs_bottom)
axes = plt.gcf().get_axes()
fig.text(
0.5,
1.01,
"SV Calls vs Whole Genome Mean Depth",
ha="center",
va="center",
)
for idx, ax in enumerate(axes[:2]):
ax.text(
-0.1,
1.05,
chr(65 + idx), # A, B
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
for idx, ax in enumerate(axes[2:]):
ax.text(
-0.1,
1.05,
chr(67 + idx), # C
transform=ax.transAxes,
fontsize=12,
fontweight="bold",
va="top",
)
fig.set_constrained_layout(True)
return fig
except Exception as e:
logger.error(f"Error creating combined SV analysis plot: {str(e)}")
raise
combined_sv_analysis_plot = create_combined_sv_depth_plot(
wg_depth_df, total_depth_df, sv_data_df
)